Skip to content

Instantly share code, notes, and snippets.

@plaidfinch
Last active November 15, 2015 04:14
Show Gist options
  • Save plaidfinch/1b4e227e476353e775fe to your computer and use it in GitHub Desktop.
Save plaidfinch/1b4e227e476353e775fe to your computer and use it in GitHub Desktop.
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TeachingGADTs where
import Control.Applicative
import Control.Monad
import Unsafe.Coerce
-- Motivation: how would you write a "typed" expression language and its evaluator?
-- We use GADT *syntax* here but we don't use GADT *semantic* features
-- | 'Maybe' re-declared with GADT-style syntax, as a warm-up: instead of
-- listing constructor fields, each constructor is introduced by writing out
-- its full type signature.
data Maybe' t where
  -- | Wraps a value that is present.
  Just' :: t -> Maybe' t
  -- | The absent case; carries no payload.
  Nothing' :: Maybe' t
-- | Untyped abstract syntax for a tiny expression language. Written in GADT
-- syntax, but every constructor returns plain 'Exp', so the type records
-- nothing about whether an expression denotes an 'Int' or a 'Bool'.
data Exp where
  -- Sum of two sub-expressions.
  PlusExp :: Exp -> Exp -> Exp
  -- Strict less-than comparison of two sub-expressions.
  LessExp :: Exp -> Exp -> Exp
  -- Integer literal.
  IntExp :: Int -> Exp
  -- Boolean literal.
  BoolExp :: Bool -> Exp
  -- If-then-else: condition, then-branch, else-branch.
  IfTeExp :: Exp -> Exp -> Exp -> Exp
-- | Evaluate an untyped 'Exp'. The DSL gets no compile-time type checking, so
-- the check happens here, at run time, and is easy to get wrong: integers come
-- back as 'Left', booleans as 'Right', and 'Nothing' signals an ill-typed
-- expression (a failed pattern bind in the 'Maybe' monad yields 'Nothing').
eval :: Exp -> Maybe (Either Int Bool)
eval expr = case expr of
  PlusExp lhs rhs -> do
    Left a <- eval lhs
    Left b <- eval rhs
    pure (Left (a + b))
  LessExp lhs rhs -> do
    Left a <- eval lhs
    Left b <- eval rhs
    pure (Right (a < b))
  IntExp n -> Just (Left n)
  BoolExp flag -> Just (Right flag)
  IfTeExp cond thenE elseE -> do
    Right choose <- eval cond
    thenV <- eval thenE
    elseV <- eval elseE
    -- Both branches must agree on Int-vs-Bool; this check is easy to forget!
    guard (sameTag thenV elseV)
    pure (if choose then thenV else elseV)
  where
    -- True exactly when both results carry the same 'Either' tag.
    sameTag (Left _) (Left _) = True
    sameTag (Right _) (Right _) = True
    sameTag _ _ = False
-- | A small /typed/ expression language: the index of @Expr@ is the Haskell
-- type that the expression evaluates to, so ill-typed terms cannot be built.
data Expr t where
  -- | Addition at any numeric type.
  Plus :: Num n => Expr n -> Expr n -> Expr n
  -- | Comparison at any ordered type; the result is always a 'Bool'.
  Less :: Ord o => Expr o -> Expr o -> Expr Bool
  -- | Embed an ordinary Haskell value.
  Lift :: v -> Expr v
  -- | Conditional: a 'Bool' condition and two branches of the same type.
  IfTe :: Expr Bool -> Expr r -> Expr r -> Expr r
-- | The same language as 'Expr', spelled with ordinary constructor syntax:
-- each GADT constructor becomes an existentially quantified constructor whose
-- context states the required type equalities explicitly with @~@.
data Expr' a
  = forall n. (a ~ n, Num n) => Plus' (Expr' n) (Expr' n)
  | forall o. (a ~ Bool, Ord o) => Less' (Expr' o) (Expr' o)
  | forall v. (a ~ v) => Lift' v
  | forall r c. (a ~ r, c ~ Bool) => IfTe' (Expr' c) (Expr' r) (Expr' r)
-- | Evaluate a typed expression. The result type is read straight off the
-- expression's index, so no run-time tagging or failure case is needed.
evaluate :: Expr a -> a
evaluate (Plus lhs rhs) = evaluate lhs + evaluate rhs
evaluate (Less lhs rhs) = evaluate lhs < evaluate rhs
evaluate (Lift value) = value
evaluate (IfTe cond thenE elseE)
  | evaluate cond = evaluate thenE
  | otherwise = evaluate elseE
-- | A value of type @l :~: r@ is a witness that @l@ and @r@ are the same type.
-- Its only constructor, 'Refl', relates a type to itself, so pattern matching
-- on it brings the equality into scope.
data l :~: r where
  Refl :: t :~: t
-- | Discharge an equality proof: convert a value of type @x@ into the equal
-- type @y@. The proof is pattern-matched (and therefore forced) before the
-- value is returned -- notice what happens when you pass @undefined@ as the
-- proof.
coerce :: x :~: y -> x -> y
coerce proof value = case proof of
  Refl -> value
-- | A witness that the constraint @ctx t@ holds: the 'Dict' constructor can
-- only be built where that constraint is satisfied, and it carries the
-- evidence along with it.
data Dict ctx t where
  Dict :: ctx t => Dict ctx t
-- | Discharge a constraint proof: run a computation that needs @c a@, using
-- the evidence stored in the given 'Dict'.
withDict :: Dict c a -> (c a => b) -> b
withDict evidence body = case evidence of
  Dict -> body
-- | 'Dict' specialised to the 'Show' class: a concrete example to "Show"
-- what is going on. Holding an 'IsShow' value proves that @t@ has a 'Show'
-- instance.
data IsShow t where
  IsShow :: Show t => IsShow t
-- | A weird GADT: the promoted 'Bool' index records whether the hidden payload
-- is known to be an 'Int'.
data IsInt (known :: Bool) where
  -- | The payload's type is constrained to equal 'Int'.
  Yep :: (t ~ Int) => t -> IsInt 'True
  -- | The payload may be of any type at all.
  Perhaps :: t -> IsInt 'False
-- | Extract the 'Int' from an @IsInt 'True@. This is total even though only
-- 'Yep' is handled: 'Perhaps' always has index @'False@, so it can never be
-- passed here.
getMeAnIntPlease :: IsInt 'True -> Int
getMeAnIntPlease wrapped = case wrapped of
  Yep n -> n
-- Ordinary value-level Peano naturals: 'Z' is zero and 'S' is successor.
-- With DataKinds the same declaration also provides the kind 'Nat' and the
-- type-level constructors used as indices by 'SNat' and 'Vec'.
data Nat = S Nat | Z
-- | Singleton naturals: a value-level number whose type index is the matching
-- type-level 'Nat' (the type-level 'Nat' comes from DataKinds). Each index has
-- exactly one constructor shape, so matching on the value reveals the index.
data SNat (k :: Nat) where
  -- | Zero, indexed by 'Z'.
  SZ :: SNat Z
  -- | Successor: one more than the wrapped singleton.
  SS :: SNat k -> SNat (S k)
-- | Add two singleton naturals. The result's index is the type-level sum, so
-- the type index is preserved properly; the recursion mirrors the equations
-- of the type family @+@.
plus :: SNat n -> SNat m -> SNat (n + m)
plus lhs rhs = case lhs of
  SZ -> rhs
  SS lhs' -> SS (plus lhs' rhs)
-- Instances for GADTs are derived this way: with a standalone deriving
-- declaration (enabled by the StandaloneDeriving pragma at the top of the
-- file) rather than a deriving clause on the data declaration.
deriving instance Show (SNat n)
-- | A length-indexed list: the first index is the number of elements, tracked
-- as a type-level 'Nat'.
data Vec (len :: Nat) (e :: *) where
  -- | The empty vector, of length 'Z'.
  Nil :: Vec Z e
  -- | Prepend one element, bumping the length index by one.
  Cons :: e -> Vec len e -> Vec (S len) e

-- | Standalone-derived 'Show' instance, available whenever elements are showable.
deriving instance Show e => Show (Vec len e)
-- An excursion: what we would do with DataKinds, but WITHOUT TypeFamilies & GADTs
-- Here we use phantom types and MODULE ABSTRACTION to manually verify & enforce invariants
-- | A plain list tagged with a phantom length index. Nothing in the type ties
-- the index to the real length of the list; the invariant holds only if every
-- 'FakeVec' is built through 'fakeNil' and 'fakeCons'.
data FakeVec (len :: Nat) (e :: *) = FakeVec [e]
  -- OH NO: deriving 'Read' accidentally breaks the invariant, because a string
  -- can now be read at any phantom length -- and we are now sad.
  deriving (Read)
-- Smart constructor for the empty 'FakeVec': wraps the empty list and tags it
-- with phantom length 'Z'.
fakeNil :: FakeVec Z a
fakeNil = FakeVec []
-- | Smart constructor that prepends one element and bumps the phantom length
-- index to match.
fakeCons :: a -> FakeVec n a -> FakeVec (S n) a
fakeCons hd (FakeVec tl) = FakeVec (hd : tl)
-- | Split a 'FakeVec' of claimed non-zero length into its head and tail.
-- The phantom index is not checked by the compiler, so an empty underlying
-- list is still possible and is reported with a run-time 'error'.
fakeVecSplit :: FakeVec (S n) a -> (a, FakeVec n a)
fakeVecSplit (FakeVec contents) = case contents of
  hd : tl -> (hd, FakeVec tl)
  [] -> error "invariant violation: FakeVec is empty!"
-- We can violate the invariant by reading a string at the "wrong" phantom type!
-- </excursion>
-- | An example 'Vec': three characters, with the length 3 spelled out in the
-- type as @S (S (S Z))@.
vec1 :: Vec (S (S (S Z))) Char
vec1 = Cons 'A' $ Cons 'B' $ Cons 'C' Nil
-- | Zip two vectors of the same length. This is exhaustive (total): both
-- arguments share the index @n@, so the mixed 'Nil'/'Cons' cases cannot occur.
zipSame :: Vec n a -> Vec n b -> Vec n (a, b)
zipSame Nil Nil = Nil
zipSame (Cons l ls) (Cons r rs) = Cons (l, r) (zipSame ls rs)
-- | The length of a vector, returned as the singleton 'SNat' for its index.
-- The elements themselves are never inspected.
vecLength :: Vec n a -> SNat n
vecLength Nil = SZ
vecLength (Cons _ rest) = SS (vecLength rest)
-- | Type-level addition, defined by recursion on the left argument
-- (requires the TypeFamilies and TypeOperators extensions).
type family a + b where
  Z + r = r
  S l + r = S (l + r)
-- | Append two vectors; the result length is @m + n@. GHC verifies this
-- automatically, because the function follows exactly the same recursion
-- pattern as the type family @+@.
easyAppend :: Vec m a -> Vec n a -> Vec (m + n) a
easyAppend front back = case front of
  Nil -> back
  Cons hd tl -> Cons hd (easyAppend tl back)
-- | Append two vectors, but with the result length written as @n + m@.
-- Type-level @+@ is not /automatically/ known to be commutative, so this only
-- typechecks because we PROVE IT: matching on the lemma 'additionCommutative'
-- brings the needed equality into scope around 'easyAppend'.
hardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
hardAppend front back = case commutes of
  Refl -> easyAppend front back
  where
    -- Proof that the two orders of summing the lengths agree.
    commutes = additionCommutative (vecLength front) (vecLength back)
-- | Sub-lemma: zero is a right neutral element for @+@, i.e. @n = n + Z@.
-- Proved by induction on the singleton for @n@.
rightNeutral :: SNat n -> n :~: (n + Z)
rightNeutral nat = case nat of
  SZ -> Refl
  SS prev -> case rightNeutral prev of
    Refl -> Refl
-- | Sub-lemma: a successor on the right of @+@ can be pulled to the outside,
-- i.e. @n + S m = S (n + m)@. Proved by induction on the first singleton; the
-- second argument only fixes the type @m@ and is never inspected.
plusSucc :: SNat n -> SNat m -> (n + S m) :~: S (n + m)
plusSucc left right = case left of
  SZ -> Refl
  SS prev -> case plusSucc prev right of
    Refl -> Refl
-- | Using 'rightNeutral' and 'plusSucc', prove for any given pair of
-- singletons that addition commutes: @n + m = m + n@.
--
-- Note: this is O(n^2) and MUST BE EXECUTED AT RUN-TIME -- the proof recurses
-- over the first argument and calls 'plusSucc' at every step.
additionCommutative :: SNat n -> SNat m -> (n + m) :~: (m + n)
additionCommutative left right = case left of
  -- Base case: Z + m reduces to m, so it remains to show m = m + Z.
  SZ -> case rightNeutral right of
    Refl -> Refl
  -- Inductive step: use the hypothesis, then move the successor outwards.
  SS prev -> case additionCommutative prev right of
    Refl -> case plusSucc right prev of
      Refl -> Refl
-- | Type-level minimum of two naturals; used to type-check a truncating zip.
type family Min (m :: Nat) (n :: Nat) where
  Min Z r = Z
  Min l Z = Z
  Min (S l) (S r) = S (Min l r)
-- | Truncating zip (a la Haskell's ordinary @zip@): pairing stops as soon as
-- either vector runs out, and the result length is the type-level 'Min' of
-- the two input lengths.
zipMin :: Vec m a -> Vec n b -> Vec (Min m n) (a, b)
zipMin Nil _ = Nil
zipMin _ Nil = Nil
zipMin (Cons l ls) (Cons r rs) = Cons (l, r) (zipMin ls rs)
-- Addendum: making things go fast again:
-- And here's how we can make things go fast, unsafely
-- This is more or less what a dependently typed language can do in some circumstances,
-- because if it is total, it knows that it doesn't actually need to run proofs to
-- make sure they're not bottom.
-- If there's a runtime-costly proof which you are ABSOLUTELY CERTAIN will never be equal
-- to bottom (i.e. is the result of a DEFINITELY TOTAL function), you can wrap it in this
-- function to avoid ever forcing it and doing the extra work to run the proof.
-- | Replace an equality proof by a freshly made 'Refl' without ever forcing
-- the original. UNSAFE: the argument is ignored entirely, so this is only
-- sound when the erased proof is certain never to be bottom.
unsafeEraseProof :: forall a b. (a :~: b) -> (a :~: b)
unsafeEraseProof _ =
  -- 'Refl' is built at the trivially true type and then recast to @a :~: b@.
  unsafeCoerce (Refl :: a :~: a)
-- | Append two vectors at the type that requires addition to be commutative,
-- but skip actually running the proof at run time: the lemma is wrapped in
-- 'unsafeEraseProof', which never forces it.
fastHardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
fastHardAppend front back = case erased of
  Refl -> easyAppend front back
  where
    -- The commutativity proof is passed along unevaluated and then discarded.
    erased = unsafeEraseProof (additionCommutative (vecLength front) (vecLength back))
-- | A vector whose length is existentially quantified: the index is hidden by
-- the constructor, so vectors of any length share this one type.
data SomeVec e where
  SomeVec :: Vec len e -> SomeVec e

-- | Standalone-derived 'Show' instance, available whenever elements are showable.
deriving instance Show e => Show (SomeVec e)
-- | Convert a list into a vector of existentially quantified length. The
-- length depends on the run-time list, so it has to be hidden in 'SomeVec'.
toVec :: [a] -> SomeVec a
toVec = foldr consSome (SomeVec Nil)
  where
    -- Unpack the already-converted tail and put one more element in front.
    consSome :: e -> SomeVec e -> SomeVec e
    consSome hd (SomeVec tl) = SomeVec (Cons hd tl)
-- | Append two existentially wrapped vectors with 'hardAppend'.
-- This will run sloooowwwwwly -- O(n^2) -- because the commutativity proof is
-- actually evaluated.
slowHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
slowHardAppendTest (SomeVec front) (SomeVec back) =
  SomeVec (hardAppend front back)
-- | Append two existentially wrapped vectors with 'fastHardAppend'.
-- This will run quick -- O(n) -- because the commutativity proof is erased
-- instead of evaluated.
fastHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
fastHardAppendTest (SomeVec front) (SomeVec back) =
  SomeVec (fastHardAppend front back)
-- for instance, try out:
-- > slowHardAppendTest (toVec [0..3000]) (toVec [0..3000])
-- > fastHardAppendTest (toVec [0..3000]) (toVec [0..3000])
-- Notice that there is a noticeable pause before slowHardAppendTest begins printing output.
-- This is the time it takes to force the thunk which evaluates to the proof of addition being
-- commutative; that is, this entire time is spent evaluating the call to additionCommutative
-- inside hardAppend.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment