Skip to content

Commit

Permalink
Use function type signature for top level lambda collection
Browse files Browse the repository at this point in the history
  • Loading branch information
hmontero1205 committed Nov 26, 2021
1 parent 60d04a0 commit 3b964e9
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions src/IR/LambdaLift.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ import qualified Data.Map as M
import Data.Maybe ( catMaybes )
import qualified Data.Set as S

import Debug.Trace
import Prettyprinter

data LiftCtx = LiftCtx
{ globalScope :: S.Set I.VarId
, currentScope :: S.Set I.VarId
Expand All @@ -42,14 +39,14 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a)
deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass)

exprType :: I.Expr Poly.Type -> Poly.Type
exprType (I.Lambda _ _ t ) = t
exprType (I.App _ _ t ) = t
exprType (I.Lit _ t ) = t
exprType (I.Var _ t ) = t
exprType (I.Prim _ _ t ) = t
exprType (I.Let _ _ t ) = t
exprType (I.Data _ t ) = t
exprType (I.Match _ _ t) = t
exprType (I.Lambda _ _ t) = t
exprType (I.App _ _ t) = t
exprType (I.Lit _ t ) = t
exprType (I.Var _ t ) = t
exprType (I.Prim _ _ t ) = t
exprType (I.Let _ _ t ) = t
exprType (I.Data _ t ) = t
exprType (I.Match _ _ t ) = t

zipArgsWithArrow :: [Binder] -> Poly.Type -> [(Binder, Poly.Type)]
zipArgsWithArrow (b : bs) (Poly.TBuiltin (Poly.Arrow t ts)) =
Expand Down Expand Up @@ -111,14 +108,14 @@ makeLiftedLambda ((v, t) : vs) body = do

liftLambdas'
:: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type)
liftLambdas' (v, lam@(I.Lambda _ _ t)) = do
liftLambdas' (v, lam@I.Lambda{}) = do
let (vs, body) = I.collectLambda lam
vs' = zipArgsWithArrow vs t
newScope $ catMaybes vs
liftedBody <- liftLambdas body
traceM (show v ++ "and body type " ++ show (pretty $exprType liftedBody))
liftedLambda <- makeLiftedLambda vs' liftedBody
return (v, liftedLambda)
return (v, updateBody lam liftedBody)
where
updateBody (I.Lambda v' b' t') b = I.Lambda v' (updateBody b' b) t'
updateBody _ b = b
liftLambdas' _ = error "Expected top-level lambda binding"

descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type)
Expand All @@ -133,8 +130,7 @@ descend body = do
liftLambdasInArm
:: (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type)
liftLambdasInArm (I.AltLit l, arm) = do
(liftedArm, armFrees) <-
descend $ return () >> liftLambdas arm
(liftedArm, armFrees) <- descend $ liftLambdas arm
modify $ \st -> st { freeTypes = armFrees }
return (I.AltLit l, liftedArm)
liftLambdasInArm (I.AltDefault b, arm) = do
Expand Down

0 comments on commit 3b964e9

Please sign in to comment.