Last active
August 23, 2024 00:18
-
-
Save AndrasKovacs/7f81ac652052829809611236b018442e to your computer and use it in GitHub Desktop.
HOAS-only lambda calculus
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# language BlockArguments, LambdaCase, Strict, UnicodeSyntax #-} | |
{-|
Minimal dependent lambda calculus with:
- HOAS-only representation
- Lossless printing
- Bidirectional checking
- Efficient evaluation & conversion checking
Inspired by https://gist.github.com/Hirrolot/27e6b02a051df333811a23b97c375196
Discussion: https://www.reddit.com/r/ProgrammingLanguages/comments/11bo39o/how_to_implement_dependent_types_in_80_lines_of/
The idea is to always compute two semantics at the same time:
- A strict one which computes no redexes (the syntax, used for printing and type checking).
- A lazy one which computes all redexes (the value, used for evaluation and conversion checking).
-}
import Prelude hiding (pi) | |
-- | De Bruijn levels: binders are numbered from the outside in, starting at 0.
type Lvl = Int

-- | Types are represented by the same datatype as terms.
type Ty = Val

-- | A term together with its value. The non-@~@ fields hold the syntax and
-- are strict (the file enables @Strict@); the final @~@ field holds the value
-- and is only computed on demand. Binders are HOAS: a body is a Haskell
-- function of the bound variable.
data Val
  = Var Lvl ~Val
    -- ^ Variable (by level) and its value.
  | App Val Val ~Val
    -- ^ Function, argument, and the value of the application.
  | Lam (Val -> Val) ~Val
    -- ^ HOAS body and the value of the lambda.
  | Let Val Val (Val -> Val) ~Val
    -- ^ Definition, its type, HOAS body, and the value of the whole @let@.
  | Pi Val (Val -> Val) ~Val
    -- ^ Domain, HOAS codomain, and the value of the function type.
  | U
    -- ^ The universe; it is its own value.
eval :: Val -> Val | |
eval = \case | |
Var _ v -> v | |
App _ _ v -> v | |
Lam _ v -> v | |
Let _ _ _ v -> v | |
Pi _ _ v -> v | |
U -> U | |
-- | Tie a knot: the shared fixed point of @f@. Used to build values whose
-- value field points back at themselves.
fix :: (a -> a) -> a
fix f = knot where knot = f knot
{-# inline fix #-}
-- | A rigid variable at the given level: a 'Var' whose value is itself, so
-- it never reduces.
rigid :: Lvl -> Val
rigid = fix . Var
-- | Smart constructor for @let@: definition, its type, and HOAS body. The
-- value is the body's value with the definition's value substituted.
let_ :: Val -> Ty -> (Val -> Val) -> Val
let_ def ty body = Let def ty body (eval $ body $ eval def)
-- | Smart constructor for the dependent function type @(x : dom) -> cod x@.
-- The value uses the domain's value and evaluates the codomain's result.
pi :: Ty -> (Val -> Ty) -> Ty
pi dom cod = Pi dom cod $ fix $ Pi (eval dom) (eval . cod)
-- | Smart constructor for lambdas. The value is a lambda whose body returns
-- the value of the syntactic body's result.
lam :: (Val -> Val) -> Val
lam body = Lam body $ fix $ Lam $ eval . body
infixl 8 ∙
-- | Application. If the function's value is a lambda, the application's
-- value is the beta-reduced result; otherwise it is a stuck application of
-- the function's value to the argument's value.
(∙) :: Val -> Val -> Val
fun ∙ arg = App fun arg $ case eval fun of
  Lam body _ -> body (eval arg)
  stuck -> fix (App stuck (eval arg))
infixr 4 ==>
-- | Non-dependent function type: the codomain ignores the bound variable.
(==>) :: Ty -> Ty -> Ty
dom ==> cod = pi dom (\_ -> cod)
-- | Lossless printing of the syntax (not the value). Binders are printed as
-- their de Bruijn level, counting from 0 at the top.
instance Show Val where
  show v = render 0 v ""
    where
      render :: Lvl -> Val -> ShowS
      render lvl = \case
        Var x _ -> shows x
        App f a _ ->
          showChar '(' . render lvl f . showChar ' ' . render lvl a . showChar ')'
        Lam body _ ->
          showString "(λ " . shows lvl . showString ". "
            . render (lvl + 1) (body (rigid lvl)) . showChar ')'
        Let def ty body _ ->
          showString "(let " . shows lvl . showString " : " . render lvl ty
            . showString " = " . render lvl def . showString "; "
            . render (lvl + 1) (body (rigid lvl)) . showChar ')'
        Pi dom cod _ ->
          showString "((" . shows lvl . showString " : " . render lvl dom
            . showString ") -> " . render (lvl + 1) (cod (rigid lvl)) . showChar ')'
        U -> showChar 'U'
-- | Conversion check at binder depth @l@. Compares the two terms
-- structurally (its callers pass evaluated types). Under binders it applies
-- both sides to a fresh rigid variable at level @l@. A lambda facing a
-- non-lambda is handled by eta: the non-lambda side is applied to that
-- variable.
conv :: Lvl -> Val -> Val -> Bool
conv l lhs rhs = case (lhs, rhs) of
  (Var x _, Var y _) -> x == y
  (App f u _, App g w _) -> conv l f g && conv l u w
  (Lam f _, Lam g _) -> underBinder f g
  (Lam f _, _) -> underBinder f (rhs ∙)
  (_, Lam g _) -> underBinder (lhs ∙) g
  (Pi a b _, Pi a' b' _) -> conv l a a' && underBinder b b'
  (U, U) -> True
  _ -> False
  where
    -- Instantiate both bodies with the same fresh variable and compare one
    -- level deeper.
    underBinder k k' = let v = rigid l in conv (l + 1) (k v) (k' v)
-- | Infer the type of a term at depth @l@ in context @cxt@.
--
-- @cxt@ lists the types of the bound variables with the innermost first, so
-- the variable at level @x@ has type @cxt !! (l - x - 1)@. Returns 'Nothing'
-- when the term is ill-typed, is a bare lambda (lambdas can only be
-- checked), or mentions a variable outside the context. Types pushed onto
-- the context here are evaluated with 'eval'.
infer :: Lvl -> [Ty] -> Val -> Maybe Ty
infer l cxt = \case
  Var x _ -> lookupTy (l - x - 1)
  App t u _ -> do
    fty <- infer l cxt t
    case fty of
      Pi dom cod _ -> check l cxt u dom >> pure (cod (eval u))
      _ -> Nothing
  -- Lambdas carry no domain annotation, so their type cannot be inferred.
  Lam{} -> Nothing
  Let t a u _ -> do
    check l cxt a U
    let va = eval a
    check l cxt t va
    -- The bound variable is a 'Var' whose value is the definition's value:
    -- it stays a variable in the syntax but is substituted when evaluated.
    infer (l + 1) (va : cxt) (u (Var l (eval t)))
  Pi a b _ -> do
    check l cxt a U
    check (l + 1) (eval a : cxt) (b (rigid l)) U
    pure U
  U -> pure U
  where
    -- Total replacement for @cxt !! ix@: a variable that is not in scope
    -- (negative or too-large index) is a type error, not a crash.
    lookupTy ix
      | ix < 0 = Nothing
      | otherwise = case drop ix cxt of
          ty : _ -> pure $! ty
          [] -> Nothing
-- | Check a term against an expected type at depth @l@ in context @cxt@.
--
-- * A lambda against a Pi type: both are instantiated with a fresh rigid
--   variable and the body is checked against the codomain.
-- * A @let@: the definition is checked as in 'infer', and then the body is
--   checked against the expected type. Previously this fell back to
--   inference, which rejects a @let@ whose body is a lambda.
-- * Anything else: the type is inferred and compared with 'conv'.
check :: Lvl -> [Ty] -> Val -> Ty -> Maybe ()
check l cxt tm ty = case (tm, ty) of
  (Lam body _, Pi dom cod _) ->
    let v = rigid l in check (l + 1) (dom : cxt) (body v) (cod v)
  (Let def defTy body _, _) -> do
    check l cxt defTy U
    let vDefTy = eval defTy
    check l cxt def vDefTy
    -- Levels do not shift under binders, so @ty@ is still valid here.
    check (l + 1) (vDefTy : cxt) (body (Var l (eval def))) ty
  _ -> do
    inferred <- infer l cxt tm
    if conv l ty inferred then pure () else Nothing
-------------------------------------------------------------------------------- | |
-- | Sample program: Church-encoded naturals with zero, successor, addition
-- and multiplication. Its body is the numeral 1000, built as (10 * 10) * 10
-- where 10 = 5 + 5. @zero@ and @suc@ are defined but not used.
test :: Val
test =
  -- Nat = (R : U) -> (R -> R) -> R -> R
  let_ (pi U \r -> (r ==> r) ==> r ==> r) U \nat ->
  -- zero = λ R f x. x
  let_ (lam \r -> lam \f -> lam \x -> x) nat \zero ->
  -- suc n = λ R f x. f (n R f x)
  let_ (lam \n -> lam \r -> lam \f -> lam \x -> f ∙ (n ∙ r ∙ f ∙ x))
    (nat ==> nat) \suc ->
  -- add a b = λ R f x. a R f (b R f x)
  let_ (lam \m -> lam \k -> lam \r -> lam \f -> lam \x -> m ∙ r ∙ f ∙ (k ∙ r ∙ f ∙ x))
    (nat ==> nat ==> nat) \add ->
  -- mul a b = λ R f x. a R (b R f) x
  let_ (lam \m -> lam \k -> lam \r -> lam \f -> lam \x -> m ∙ r ∙ (k ∙ r ∙ f) ∙ x)
    (nat ==> nat ==> nat) \mul ->
  -- five = λ R f x. f (f (f (f (f x))))
  let_ (lam \r -> lam \f -> lam \x -> f ∙ (f ∙ (f ∙ (f ∙ (f ∙ x)))))
    nat \n5 ->
  let_ (add ∙ n5 ∙ n5) nat \n10 ->
  let_ (mul ∙ n10 ∙ n10) nat \n100 ->
  let_ (mul ∙ n100 ∙ n10) nat \n1000 ->
  n1000
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment