Skip to content

Instantly share code, notes, and snippets.

@ChrisPenner
Last active February 9, 2022 19:52
Show Gist options
  • Save ChrisPenner/b1d88dc96912b1bb7b5cb9e7409d5277 to your computer and use it in GitHub Desktop.
Save ChrisPenner/b1d88dc96912b1bb7b5cb9e7409d5277 to your computer and use it in GitHub Desktop.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RankNTypes #-}
module Recurser where
import Control.Lens
import Data.Monoid
import Data.Foldable
import Data.Functor.Contravariant
-- | Names used by 'Var' and as the binder in 'Lam'.
type Name = String

-- | Object-language types: a base 'IntegerType' and a two-component
-- 'FunType' (which may nest arbitrarily).
--
-- 'Eq' is derived (in addition to 'Show') so types can be compared directly.
data Type
  = IntegerType
  | FunType Type Type
  deriving (Eq, Show)

-- | Object-language terms. 'Lam' carries a 'Name', a 'Type' annotation and
-- a body; 'App' and 'Plus' each have two subterms; 'Var' and 'Constant' are
-- leaves.
--
-- 'Eq' is derived (in addition to 'Show') so rewritten terms, e.g. the
-- output of a constant-folding pass, can be compared with '=='.
data Term
  = Var Name
  | Lam Name Type Term
  | App Term Term
  | Plus Term Term
  | Constant Integer
  deriving (Eq, Show)
-- | One step of constant folding: an addition whose two operands are both
-- literal constants becomes a single constant holding their sum. Only the
-- node passed in is inspected; every other term is returned unchanged.
-- Pair it with a bottom-up traversal (as 'flattenConsts' does) to collapse
-- nested sums.
cf :: Term -> Term
cf (Plus (Constant lhs) (Constant rhs)) = Constant (lhs + rhs)
cf other = other
-- | Bottom-up monadic rewrite of a term.
--
-- For each node, its children are rewritten first (left to right), the node
-- is rebuilt from the rewritten children, and only then is @f@ applied to the
-- rebuilt node. Leaves ('Var', 'Constant') are passed to @f@ as they are.
--
-- With a @Plated Term@ instance this would be lens's @transformM@.
termsS :: Monad m => (Term -> m Term) -> Term -> m Term
termsS f = go
  where
    -- rewrite the children, then hand the rebuilt node to f
    go term = rebuild term >>= f

    rebuild (Lam name ty body) = Lam name ty <$> go body
    rebuild (App fun arg) = App <$> go fun <*> go arg
    rebuild (Plus lhs rhs) = Plus <$> go lhs <*> go rhs
    rebuild leaf = pure leaf
-- | Fold over a term and every one of its subterms, in pre-order: the node
-- itself is visited first, then its children left to right. 'Var' and
-- 'Constant' have no children.
--
-- With a @Plated Term@ instance this would simply be @cosmos@.
termsF :: Fold Term Term
termsF f = visit
  where
    visit term = f term *> descend term

    -- The binder name and annotation of a lambda are not terms, so only
    -- the body is descended into.
    descend (Lam _ _ body) = visit body
    descend (App fun arg) = visit fun *> visit arg
    descend (Plus lhs rhs) = visit lhs *> visit rhs
    descend leaf = pure leaf
-- | Sample term: a lambda named @"Add"@, annotated with 'IntegerType', whose
-- body is @(1 + 2) + 3@ built from 'Plus' and 'Constant'.
exampleTerm :: Term
exampleTerm = Lam "Add" IntegerType body
  where
    body = Plus onePlusTwo (Constant 3)
    onePlusTwo = Plus (Constant 1) (Constant 2)
-- | Constant-fold a whole term by applying 'cf' bottom-up via 'termsS'.
-- Because children are folded before their parent, nested sums of
-- constants collapse completely. For example, the body @(1 + 2) + 3@ of
-- 'exampleTerm' becomes @Constant 6@.
--
-- With a @Plated Term@ instance this could be written @transform cf@.
flattenConsts :: Term -> Term
flattenConsts = termsS %~ cf
-- | Number of subterms in a term as enumerated by 'termsF', including the
-- term itself. For example, 'exampleTerm' has 6: the lambda, two 'Plus'
-- nodes and three constants.
countSubterms :: Term -> Int
countSubterms term = length (toListOf termsF term)
-- | Fold over the type nodes of a term's /own/ annotation, as enumerated by
-- 'typesF'.
--
-- Only the root term is inspected:
--
-- * a 'Lam' yields the nodes of its 'Type' annotation;
-- * every other constructor yields nothing;
-- * nested subterms are never entered.
--
-- Compose with 'termsF' (@termsF . termTypesF@) to reach the annotations of
-- every subterm.
--
-- Note: this is /not/ equivalent to @biplate . cosmos@. 'biplate' would also
-- find annotations inside nested subterms, and it requires a @Data Term@
-- instance.
termTypesF :: Fold Term Type
termTypesF f t = case t of
  -- A Fold's functor is Contravariant as well as a Functor, so its
  -- argument is phantom; 'phantom' re-tags @f Type@ as the required
  -- @f Term@.
  Lam _ ty _ -> phantom (typesF f ty)
  _ -> pure t
-- | Fold over a type and every type nested inside it, in pre-order.
--
-- The order is: the type itself, then all nodes of the first 'FunType'
-- component, then all nodes of the second.
--
-- Both components of a 'FunType' are folded recursively, so nested function
-- types are explored completely. For example,
-- @FunType (FunType IntegerType IntegerType) IntegerType@ yields 5 nodes.
-- (Previously only the immediate components were visited, which is not what
-- @cosmos@ does.)
--
-- With a @Plated Type@ instance this is equivalent to @cosmos@.
typesF :: Fold Type Type
typesF f t = case t of
  FunType ty1 ty2 -> f t *> typesF f ty1 *> typesF f ty2
  IntegerType -> f t
-- | Count the nodes of a term. Each subterm (as enumerated by 'termsF')
-- contributes one, plus one for every type node that 'termTypesF' reports
-- for that subterm's own annotation.
countTermNodes :: Term -> Int
countTermNodes = getSum . foldMapOf termsF weight
  where
    weight term = Sum (1 + lengthOf termTypesF term)
-- | The same node count as 'countTermNodes', returned wrapped in 'Sum'.
-- Each subterm from 'termsF' contributes 1 plus the number of type nodes
-- 'termTypesF' reports for it.
countTermNodes' :: Term -> Sum Int
countTermNodes' = foldMapOf termsF (Sum . (+ 1) . lengthOf termTypesF)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment