Last active
November 15, 2015 04:14
-
-
Save plaidfinch/1b4e227e476353e775fe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE ConstraintKinds #-} | |
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
module TeachingGADTs where | |
import Control.Applicative | |
import Control.Monad | |
import Unsafe.Coerce | |
-- Motivation: how would you write a "typed" expression language and its evaluator? | |
-- We use GADT *syntax* here but we don't use GADT *semantic* features | |
-- Introducing GADT syntax: this is the ordinary Maybe type written with one
-- explicit type signature per constructor. Both constructors return the fully
-- general result type, so only the syntax is GADT-style here.
data Maybe' a where
  Just' :: a -> Maybe' a -- wraps a value
  Nothing' :: Maybe' a -- the empty case
-- An untyped expression AST (GADT syntax only). Every constructor produces a
-- plain Exp, so nothing prevents building ill-typed terms such as
-- PlusExp (BoolExp True) (IntExp 1).
data Exp where
  PlusExp :: Exp -> Exp -> Exp -- addition of two sub-expressions
  LessExp :: Exp -> Exp -> Exp -- less-than comparison of two sub-expressions
  IntExp :: Int -> Exp -- integer literal
  BoolExp :: Bool -> Exp -- boolean literal
  IfTeExp :: Exp -> Exp -> Exp -> Exp -- condition, then-branch, else-branch
-- Evaluate an untyped Exp. Type-checking happens at run time, during
-- evaluation, and is easy to get wrong: we get no compile-time checking for
-- our DSL.
--
-- The result is Left for an Int value and Right for a Bool value. A failed
-- pattern match inside the Maybe do-blocks (e.g. an operand evaluating to a
-- Bool where an Int is required) makes the whole evaluation Nothing.
eval :: Exp -> Maybe (Either Int Bool)
eval (PlusExp e1 e2) = do
  Left i1 <- eval e1
  Left i2 <- eval e2
  pure . Left $ i1 + i2
eval (LessExp e1 e2) = do
  Left i1 <- eval e1
  Left i2 <- eval e2
  pure . Right $ i1 < i2
eval (IntExp i) = Just $ Left i
eval (BoolExp b) = Just $ Right b
eval (IfTeExp e1 e2 e3) = do
  Right b <- eval e1
  -- Both branches are evaluated, whichever one is selected, so that their
  -- result types can be compared.
  v2 <- eval e2
  v3 <- eval e3
  guard (bothLeft v2 v3 || bothRight v2 v3) -- this line would be easy to forget!
  pure $ if b then v2 else v3
  where
    -- True only when both values are Ints.
    bothLeft (Left _) (Left _) = True
    bothLeft _ _ = False
    -- True only when both values are Bools.
    bothRight (Right _) (Right _) = True
    bothRight _ _ = False
-- A small expression language whose type index records the type of value each
-- expression denotes. Unlike a plain algebraic data type, the constructors may
-- carry class constraints and may fix the result index (Less always builds an
-- Expr Bool), so ill-typed expressions are rejected by the compiler.
data Expr a where
  Plus :: (Num a) => Expr a -> Expr a -> Expr a -- sum of two numeric expressions
  Less :: (Ord a) => Expr a -> Expr a -> Expr Bool -- comparison; operand type is hidden in the result
  Lift :: a -> Expr a -- embed a Haskell value
  IfTe :: Expr Bool -> Expr a -> Expr a -> Expr a -- condition, then-branch, else-branch
-- The same expression language written without GADT syntax: each constructor
-- existentially quantifies its own variables and relates them to the type
-- parameter through explicit equality constraints (~). This is the desugared
-- shape of a GADT declaration with constructors Plus, Less, Lift and IfTe.
data Expr' a
  = forall x. (Num x, a ~ x) => Plus' (Expr' x) (Expr' x)
  | forall x. (Ord x, a ~ Bool) => Less' (Expr' x) (Expr' x)
  | forall x. (x ~ a) => Lift' x
  | forall x b. (x ~ a, b ~ Bool) => IfTe' (Expr' b) (Expr' x) (Expr' x)
-- Evaluate a typed expression to a Haskell value of the type named by its
-- index. No Maybe and no run-time type checks are needed: matching on a
-- constructor brings its constraints (Num, Ord) and its result index into
-- scope, so every branch is checked by the compiler.
evaluate :: Expr a -> a
evaluate expr = case expr of
  Plus e1 e2 -> evaluate e1 + evaluate e2
  Less e1 e2 -> evaluate e1 < evaluate e2
  Lift e -> e
  -- Only the selected branch is evaluated.
  IfTe e1 e2 e3 -> if evaluate e1 then evaluate e2 else evaluate e3
-- Witness of a type-equality proof. The only constructor, Refl, can be built
-- solely when both indices are the same type, so pattern matching on a value
-- of type x :~: y teaches the type checker that x ~ y.
data x :~: y where
  Refl :: x :~: x
-- Discharge an equality proof: given evidence that x and y are the same type,
-- return the x argument at type y. Matching on Refl is what brings x ~ y into
-- scope, so the proof argument is forced; notice that passing undefined as the
-- proof makes the call diverge rather than perform an unchecked cast.
coerce :: x :~: y -> x -> y
coerce Refl x = x
-- Witness of a constraint: a Dict c a value can only be constructed where the
-- constraint c a holds, and it carries that evidence with it.
data Dict c a where
  Dict :: c a => Dict c a
-- Discharge a constraint proof: matching on the Dict constructor brings c a
-- into scope, which is exactly what the second argument needs in order to be
-- used at type b.
--
-- NOTE(review): the argument type (c a => b) is a qualified type in argument
-- position, which GHC normally accepts only with RankNTypes; that extension is
-- not among this file's LANGUAGE pragmas -- confirm against the compiler in use.
withDict :: Dict c a -> (c a => b) -> b
withDict Dict x = x
-- A constraint witness specialised to the Show class: an example to "Show"
-- what is going on. An IsShow a value can only be built when Show a holds, and
-- matching on it makes that instance available.
data IsShow a where
  IsShow :: (Show a) => IsShow a
-- A weird GADT indexed by a promoted Bool (via DataKinds). The index records
-- whether the payload is known to be an Int.
data IsInt (x :: Bool) where
  -- The payload type is constrained to be Int.
  Yep :: (a ~ Int) => a -> IsInt 'True
  -- The payload type is existentially hidden; nothing is known about it.
  Perhaps :: a -> IsInt 'False
-- Extract the Int from an IsInt 'True. This is total even with a single
-- equation: Perhaps only builds IsInt 'False, so it cannot occur at this type,
-- and matching on Yep brings a ~ Int into scope so the payload is an Int.
-- The index is written 'True (ticked) to match the IsInt declaration.
getMeAnIntPlease :: IsInt 'True -> Int
getMeAnIntPlease (Yep x) = x
-- Ordinary value-level Peano naturals: Z is zero and S is successor. With
-- DataKinds enabled the same declaration also provides a kind Nat whose
-- type-level inhabitants are Z and S n.
data Nat = S Nat | Z
-- Singleton naturals: a value-level Nat that witnesses a type-level Nat (the
-- type-level Nat comes from DataKinds). For each index n there is exactly one
-- constructor shape, so matching on an SNat n reveals what n is.
data SNat (n :: Nat) where
  SZ :: SNat Z
  SS :: SNat n -> SNat (S n)
-- Addition of singleton naturals that tracks the type index: the result's
-- index is the type-level sum n + m. The recursion is on the first argument,
-- mirroring the equations of the type family (+), which is why no explicit
-- proof is required.
plus :: SNat n -> SNat m -> SNat (n + m)
plus SZ m = m
plus (SS n) m = SS (plus n m)
-- Instances for GADTs are derived with a standalone deriving declaration
-- (StandaloneDeriving) instead of a deriving clause on the data declaration.
deriving instance Show (SNat n)
-- Length-indexed list: the first type index is the number of elements, kept
-- in step with the constructors (Nil has length Z, Cons adds one S).
data Vec (n :: Nat) (a :: *) where
  Nil :: Vec Z a
  Cons :: a -> Vec n a -> Vec (S n) a

-- Show for vectors of any length, given Show for the elements.
deriving instance Show a => Show (Vec n a)
-- An excursion: what we would do with DataKinds, but WITHOUT TypeFamilies & GADTs.
-- Here we use phantom types and MODULE ABSTRACTION to manually verify & enforce
-- invariants: the length index n is a phantom parameter that the underlying
-- list knows nothing about, so only the functions operating on FakeVec keep
-- the two in agreement.
data FakeVec (n :: Nat) (a :: *) = FakeVec [a]
  -- OH NO: deriving Read accidentally breaks the abstraction. A derived Read
  -- instance will parse a list at whatever phantom length the caller asks
  -- for, so the length invariant can be violated from outside.
  deriving (Read)
-- The empty FakeVec; its phantom length index is Z and its list is empty.
fakeNil :: FakeVec Z a
fakeNil = FakeVec []
-- Prepend an element, bumping the phantom length index by one S so that it
-- stays in step with the underlying list.
fakeCons :: a -> FakeVec n a -> FakeVec (S n) a
fakeCons a (FakeVec as) = FakeVec (a : as)
-- Split a FakeVec whose index claims it is non-empty into head and tail.
-- The compiler cannot see that the list really is non-empty, so the empty
-- case must still be handled; it calls error, and is reachable whenever a
-- FakeVec was built at a phantom length that disagrees with its list.
fakeVecSplit :: FakeVec (S n) a -> (a, FakeVec n a)
fakeVecSplit (FakeVec (a : as)) = (a, FakeVec as)
fakeVecSplit (FakeVec []) = error "invariant violation: FakeVec is empty!"
-- We can violate the invariant by reading a string at the "wrong" phantom type! | |
-- </excursion> | |
-- An example Vec: three characters, with the length 3 (S (S (S Z))) spelled
-- out in the type.
vec1 :: Vec (S (S (S Z))) Char
vec1 = Cons 'A' (Cons 'B' (Cons 'C' Nil))
-- Zip two vectors of the same length. The two equations are exhaustive
-- (total): both arguments share the index n, so a Nil can never be paired
-- with a Cons and the mixed cases need not be written.
zipSame :: Vec n a -> Vec n b -> Vec n (a, b)
zipSame Nil Nil = Nil
zipSame (Cons x xs) (Cons y ys) = Cons (x, y) (zipSame xs ys)
-- Length of a vector as a singleton Nat (SNat): one SS per Cons, so the
-- result's index is the vector's own length index. The elements themselves
-- are never inspected.
vecLength :: Vec n a -> SNat n
vecLength Nil = SZ
vecLength (Cons _ xs) = SS $ vecLength xs
-- Type-level addition (requires TypeFamilies & TypeOperators), defined as a
-- closed type family by recursion on the first argument.
type family a + b where
  Z + n = n
  S n + m = S (n + m)
-- Append two vectors; the result length is m + n. GHC verifies the index
-- automatically because the function follows the exact same recursion pattern
-- as (+): each equation lines up with one equation of the type family.
easyAppend :: Vec m a -> Vec n a -> Vec (m + n) a
easyAppend Nil ys = ys
easyAppend (Cons x xs) ys = Cons x (easyAppend xs ys)
-- But... type-level (+) is not *automatically* known to be commutative, so an
-- append whose result index is written n + m (rather than m + n) fails to
-- typecheck unless we PROVE the two sums equal using a lemma.
--
-- The proof is built from the singleton lengths of both vectors. Matching on
-- its Refl brings (m + n) ~ (n + m) into scope, after which the ordinary
-- append has the required type. The match also forces the proof to be
-- computed at run time.
hardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
hardAppend v w =
  case additionCommutative (vecLength v) (vecLength w) of
    Refl -> easyAppend v w
-- Sub-lemma: zero is a right neutral element for (+), i.e. n = n + Z.
-- (Z + n = n holds by definition of the type family; the right-hand version
-- needs induction on n.) Matching on the recursive proof supplies the
-- induction hypothesis needed to accept Refl in the successor case.
rightNeutral :: SNat n -> n :~: (n + Z)
rightNeutral SZ = Refl
rightNeutral (SS n) =
  case rightNeutral n of
    Refl -> Refl
-- Sub-lemma: n + S m = S (n + m), proved by induction on the first argument.
-- The second singleton is never inspected; it only fixes the type index m.
plusSucc :: SNat n -> SNat m -> (n + S m) :~: S (n + m)
plusSucc SZ _ = Refl
plusSucc (SS n) m =
  case plusSucc n m of
    Refl -> Refl
-- Using the two sub-lemmas we can prove, for any given SNat n and SNat m, that
-- addition commutes: n + m = m + n. The proof is by induction on the first
-- argument.
--
-- Note: the proof MUST BE EXECUTED AT RUN-TIME by whoever matches on the
-- result, and it is quadratic: every successor step of the first argument
-- runs plusSucc, which itself recurses over the whole second argument.
additionCommutative :: SNat n -> SNat m -> (n + m) :~: (m + n)
-- Base case: Z + m reduces to m, so what remains is m = m + Z.
additionCommutative SZ m =
  case rightNeutral m of
    Refl -> Refl
-- Step: the induction hypothesis turns S (n + m) into S (m + n), and plusSucc
-- identifies that with m + S n.
additionCommutative (SS n) m =
  case additionCommutative n m of
    Refl -> case plusSucc m n of
      Refl -> Refl
-- Type-level minimum of two naturals, used to type-check a truncating zip.
-- It is zero as soon as either argument is zero; otherwise one successor is
-- peeled off both arguments.
type family Min (m :: Nat) (n :: Nat) where
  Min Z y = Z
  Min x Z = Z
  Min (S x) (S y) = S (Min x y)
-- Truncating zip (a la Haskell's ordinary zip): stops as soon as either
-- vector runs out, so the result length is Min m n. The three equations
-- correspond one-to-one with the equations of Min.
zipMin :: Vec m a -> Vec n b -> Vec (Min m n) (a, b)
zipMin Nil _ = Nil
zipMin _ Nil = Nil
zipMin (Cons x xs) (Cons y ys) = Cons (x, y) (zipMin xs ys)
-- Addendum: making things go fast again, unsafely.
--
-- This is more or less what a dependently typed language can do in some
-- circumstances: if it knows a proof is total, it knows it does not actually
-- need to run the proof to make sure it is not bottom.
--
-- If there is a runtime-costly proof which you are ABSOLUTELY CERTAIN will
-- never be equal to bottom (i.e. is the result of a DEFINITELY TOTAL
-- function), you can wrap it in this function to avoid ever forcing it and
-- doing the extra work to run the proof.
--
-- Safety: the argument is deliberately ignored and never evaluated; a fresh
-- Refl is cast to the requested equality type with unsafeCoerce. That is only
-- sound when a genuine, terminating proof of a :~: b exists. If the supplied
-- proof would have been bottom, callers obtain evidence for an equality that
-- was never established.
unsafeEraseProof :: forall a b. (a :~: b) -> (a :~: b)
unsafeEraseProof _proof =
  unsafeCoerce Refl :: a :~: b
-- Append two vectors at a type that requires addition to be commutative, but
-- skip actually running the proof at run time: the commutativity proof is
-- passed through unsafeEraseProof, so the Refl matched here does not come from
-- evaluating additionCommutative.
fastHardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
fastHardAppend v w =
  case unsafeEraseProof (additionCommutative (vecLength v) (vecLength w)) of
    Refl -> easyAppend v w
-- Existentially quantify over the length of a vector: the index n appears in
-- the constructor's argument but not in the result type, so a SomeVec a holds
-- a vector of some statically unknown length.
data SomeVec a where
  SomeVec :: Vec n a -> SomeVec a

-- Show for existentially wrapped vectors, given Show for the elements.
deriving instance Show a => Show (SomeVec a)
-- Convert a list into a vector of existentially hidden length. The length of
-- a list is not known at compile time, so the result must hide the index.
toVec :: [a] -> SomeVec a
toVec [] = SomeVec Nil
toVec (x : xs) =
  -- Unpack the converted tail to name its (hidden) length, then re-wrap the
  -- vector that is one element longer.
  case toVec xs of
    SomeVec xs' -> SomeVec (Cons x xs')
-- Append two existentially wrapped vectors via hardAppend. This will run
-- sloooowwwwwly -- O(n^2) -- because hardAppend forces the commutativity proof
-- at run time.
slowHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
slowHardAppendTest (SomeVec x) (SomeVec y) =
  SomeVec $ hardAppend x y
-- Append two existentially wrapped vectors via fastHardAppend. This will run
-- quick -- O(n) -- because fastHardAppend never evaluates the commutativity
-- proof.
fastHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
fastHardAppendTest (SomeVec x) (SomeVec y) =
  SomeVec $ fastHardAppend x y
-- for instance, try out: | |
-- > slowHardAppendTest (toVec [0..3000]) (toVec [0..3000]) | |
-- > fastHardAppendTest (toVec [0..3000]) (toVec [0..3000]) | |
-- Notice that there is a perceptible pause before slowHardAppendTest begins printing output.
-- This is the time it takes to force the thunk which evaluates to the proof of addition being
-- commutative; that is, this entire time is spent evaluating the call to additionCommutative
-- inside hardAppend.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment