From ecc913e8ff4fed7046db048d13af8792ce222e4d Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Thu, 4 Nov 2021 17:38:48 -0400 Subject: [PATCH 01/40] WIP --- app/Main.hs | 2 ++ src/IR.hs | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/app/Main.hs b/app/Main.hs index 350b0f5f..99a98408 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -129,6 +129,8 @@ main = do irL <- doPass $ IR.lambdaLift irY + when (True) $ putStrLn "DFGSDFDSF" >> exitSuccess + irI <- doPass $ IR.defunctionalize irL irD <- doPass $ IR.inferDrops irI diff --git a/src/IR.hs b/src/IR.hs index 6d0a44d4..e04512b7 100644 --- a/src/IR.hs +++ b/src/IR.hs @@ -15,6 +15,7 @@ import IR.ClassInstantiation ( instProgram ) import IR.LowerAst ( lowerProgram ) import IR.Monomorphize ( monoProgram ) import IR.TypeInference ( inferProgram ) +import IR.LambdaLift ( liftProgramLambdas ) lowerAst :: A.Program -> Compiler.Pass (I.Program Ann.Type) lowerAst = lowerProgram @@ -30,7 +31,7 @@ yieldAbstraction :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) yieldAbstraction = return lambdaLift :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) -lambdaLift = return +lambdaLift = liftProgramLambdas defunctionalize :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) defunctionalize = return From 713b83975789bb0994530dede710caa9da72afb0 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Thu, 4 Nov 2021 20:01:34 -0400 Subject: [PATCH 02/40] mvp: detect free vars --- src/IR/LambdaLift.hs | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/IR/LambdaLift.hs diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs new file mode 100644 index 00000000..bf3a96c0 --- /dev/null +++ b/src/IR/LambdaLift.hs @@ -0,0 +1,37 @@ +module IR.LambdaLift where + +import qualified Common.Compiler as Compiler +import qualified IR.IR as I + +import qualified IR.Types.Poly as Poly + +import Debug.Trace + +import qualified Data.Set as S +import Data.List (intercalate) + + + + + +liftProgramLambdas + :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) +liftProgramLambdas p = do + let defs = I.programDefs p + globalScope = map (\(v, _) -> show v) defs + funs = filter isFun defs + freeVars = map (getFrees (S.fromList globalScope) (S.fromList globalScope)) (map snd funs) + traceM (show $ zip (map (\(v, _) -> v) funs) freeVars) + return p + where + isFun (_, I.Lambda _ _ _) = True + isFun _ = False + getFrees scp gs (I.Var v _) = if S.member (show v) scp then [] else [show v] + getFrees scp gs (I.App e1 e2 _) = getFrees scp gs e1 ++ getFrees scp gs e2 + getFrees scp gs (I.Let binds e _) = let newScp = foldl (\s (Just v, _) -> S.insert (show v) s) scp binds in + (concatMap (\(_, be) -> getFrees newScp gs be) binds) ++ getFrees newScp gs e + getFrees scp gs (I.Lambda (Just v) e _) = let newScp = S.insert (show v) gs in + trace (intercalate "\n" (getFrees newScp gs e) ++ "\n--") [] + getFrees scp gs (I.Match _ _ _ _) = undefined + getFrees scp gs (I.Prim _ es _) = concatMap (getFrees scp gs) es + getFrees _ _ _ = [] From eb3720b3a476159dcbc97d562bdde76e259c4e0e Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Mon, 8 Nov 2021 18:08:23 -0500 Subject: [PATCH 03/40] Improve free variable collection --- src/IR/LambdaLift.hs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index bf3a96c0..2a1c0bdd 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -18,7 +18,7 @@ liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) liftProgramLambdas p = do let defs = I.programDefs p - globalScope = map (\(v, _) -> show v) defs + globalScope = map (\(v, _) -> v) defs funs = filter isFun defs freeVars = map (getFrees (S.fromList globalScope) (S.fromList globalScope)) (map snd funs) traceM (show $ zip (map (\(v, _) -> v) funs) freeVars) @@ -26,12 +26,17 @@ liftProgramLambdas p = do where isFun (_, I.Lambda _ _ _) = True isFun _ = False - getFrees scp gs (I.Var v _) = if S.member (show v) scp then [] else [show v] + getFrees scp gs (I.Var v _) = if S.member (v) scp then [] else [show v] getFrees scp gs (I.App e1 e2 _) = getFrees scp gs e1 ++ getFrees scp gs e2 - getFrees scp gs (I.Let binds e _) = let newScp = foldl (\s (Just v, _) -> S.insert (show v) s) scp binds in - (concatMap (\(_, be) -> getFrees newScp gs be) binds) ++ getFrees newScp gs e - getFrees scp gs (I.Lambda (Just v) e _) = let newScp = S.insert (show v) gs in - trace (intercalate "\n" (getFrees newScp gs e) ++ "\n--") [] + getFrees scp gs (I.Let binds e _) = let newScp = foldl (\s (Just v, _) -> S.insert (v) s) scp binds in + (concatMap (getLetFrees newScp gs) binds) ++ getFrees newScp gs e + getFrees scp gs lam@(I.Lambda _ _ _) = let (vs, body) = I.collectLambda lam + newScp = foldl (\s (Just v) -> S.insert v s) gs vs in + trace (intercalate "\n" (getFrees newScp gs body) ++ "\n--") [] getFrees scp gs (I.Match _ _ _ _) = undefined getFrees scp gs (I.Prim _ es _) = concatMap (getFrees scp gs) es getFrees _ _ _ = [] + getLetFrees scp gs (Just v, lam@(I.Lambda _ _ _ )) = let (vs, body) = I.collectLambda lam + newScp = foldl (\s (Just v) -> S.insert v s) gs vs in + trace (show v ++ ":\n" ++ (intercalate "\n" (getFrees newScp gs body) ++ "\n--")) [] + getLetFrees scp gs (Just _, e) = getFrees scp gs e From c5f7211d519753e78e96e1abe4b2bb3f8d0ee285 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 9 Nov 2021 15:49:04 -0500 Subject: [PATCH 04/40] wip: Lift lambdas using state monad --- src/IR/LambdaLift.hs | 97 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 2a1c0bdd..5e758dc5 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -1,16 +1,113 @@ +{-# LANGUAGE DerivingVia #-} module IR.LambdaLift where import qualified Common.Compiler as Compiler +import Common.Identifiers import qualified IR.IR as I import qualified IR.Types.Poly as Poly import Debug.Trace +import Control.Monad.Except ( MonadError(..) ) +import Control.Monad.State.Lazy ( MonadState + , StateT(..) + , evalStateT + , gets + , get + , modify + ) import qualified Data.Set as S +import qualified Data.Map as M import Data.List (intercalate) +data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId + , currentScope :: S.Set I.VarId + , currentFrees :: S.Set I.VarId + , lifted :: [(I.VarId, I.Expr Poly.Type)] + , toAdjust :: M.Map I.VarId I.VarId + , anonCount :: Int + } +newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) + deriving Functor via (StateT LiftCtx Compiler.Pass) + deriving Applicative via (StateT LiftCtx Compiler.Pass) + deriving Monad via (StateT LiftCtx Compiler.Pass) + deriving MonadFail via (StateT LiftCtx Compiler.Pass) + deriving (MonadError Compiler.Error) via (StateT LiftCtx Compiler.Pass) + deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) + +runLiftFn :: LiftFn a -> Compiler.Pass a +runLiftFn (LiftFn m) = evalStateT m LiftCtx { globalScope = S.empty, currentScope = S.empty, currentFrees = S.empty, lifted = [], toAdjust = M.empty, anonCount = 0 } + +populateGlobalScope :: [(I.VarId, I.Expr Poly.Type)] -> LiftFn () +populateGlobalScope defs = do + let globalNames = map (\(v, _) -> v) defs + modify $ \st -> st { globalScope = S.fromList globalNames } + +inCurrentScope :: I.VarId -> LiftFn Bool +inCurrentScope v = S.member v <$> gets currentScope + +getFresh :: LiftFn Int +getFresh = do + curCount <- gets anonCount + modify $ \st -> st { anonCount = anonCount st + 1 } + return curCount + +addLifted :: String -> I.Expr Poly.Type -> LiftFn () +addLifted name lam = modify $ \st -> st { lifted = ((I.VarId (Identifier name)), lam) : lifted st } + +addFreeVar :: I.VarId -> LiftFn () +addFreeVar v = modify $ \st -> st { currentFrees = S.insert v $ currentFrees st } + +newScope :: [I.VarId] -> LiftFn () +newScope vs = modify $ \st -> st { currentScope = S.union (globalScope st) (S.fromList vs), currentFrees = S.empty } + +makeLiftedLambda :: [I.Binder] -> I.Expr Poly.Type -> Poly.Type -> LiftFn (I.Expr Poly.Type) +makeLiftedLambda [] body _ = return body +makeLiftedLambda vs body t = do + liftedBody <- makeLiftedLambda (tail vs) body t + return (I.Lambda (head vs) liftedBody t) + +liftLambdas' :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) +liftLambdas' e = do + newScope [] + liftLambdas e + +liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) +liftLambdas n@(I.Var v _) = do + isNotFree <- inCurrentScope v + if isNotFree then return n + else do addFreeVar v + return n +liftLambdas (I.App e1 e2 t) = do + liftedE1 <- liftLambdas e1 + liftedE2 <- liftLambdas e2 + return $ I.App liftedE1 liftedE2 t +liftLambdas (I.Prim p exprs t) = do + liftedExprs <- mapM liftLambdas' exprs + return $ I.Prim p liftedExprs t +liftLambdas lam@(I.Lambda _ _ t) = do + let (vs, body) = I.collectLambda lam + oldCtx <- get + newScope $ map (\(Just v) -> v) vs + liftedLamBody <- liftLambdas body + lamFrees <- gets currentFrees + liftedLam <- makeLiftedLambda (map (Just) (S.toList lamFrees) ++ vs) body t + freshNum <- getFresh + addLifted ("anon" ++ (show freshNum)) liftedLam + return (foldl (\app v -> I.App app (I.Var v t) t) (I.Var (I.VarId (Identifier ("anon" ++ (show freshNum)))) t) (S.toList lamFrees)) + +liftProgramLambdas' :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) +liftProgramLambdas' p = runLiftFn $ do + let defs = I.programDefs p + funs = map snd $ filter isFun defs + populateGlobalScope defs + funsWithoutLambdas <- mapM liftLambdas funs + return p + where + isFun (_, I.Lambda _ _ _) = True + isFun _ = False From 748d918db8484576497eb16a4cf76c8d6518240b Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 9 Nov 2021 17:31:04 -0500 Subject: [PATCH 05/40] wip: Lift let expressions, start testing, but still no callsite adjustments for named lambdas --- app/Main.hs | 2 +- src/IR/LambdaLift.hs | 50 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/app/Main.hs b/app/Main.hs index 99a98408..0b830f15 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -129,7 +129,7 @@ main = do irL <- doPass $ IR.lambdaLift irY - when (True) $ putStrLn "DFGSDFDSF" >> exitSuccess + when (True) $ putStrLn (spaghetti irL) >> exitSuccess irI <- doPass $ IR.defunctionalize irL diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 5e758dc5..c3ac8a5a 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -20,6 +20,7 @@ import Control.Monad.State.Lazy ( MonadState import qualified Data.Set as S import qualified Data.Map as M import Data.List (intercalate) +import Data.Maybe (isJust) data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId , currentScope :: S.Set I.VarId @@ -48,6 +49,9 @@ populateGlobalScope defs = do inCurrentScope :: I.VarId -> LiftFn Bool inCurrentScope v = S.member v <$> gets currentScope +addCurrentScope :: I.VarId -> LiftFn () +addCurrentScope v = modify $ \st -> st { currentScope = S.insert v $ currentScope st } + getFresh :: LiftFn Int getFresh = do curCount <- gets anonCount @@ -67,6 +71,7 @@ makeLiftedLambda :: [I.Binder] -> I.Expr Poly.Type -> Poly.Type -> LiftFn (I.Exp makeLiftedLambda [] body _ = return body makeLiftedLambda vs body t = do liftedBody <- makeLiftedLambda (tail vs) body t + traceM (show vs) return (I.Lambda (head vs) liftedBody t) liftLambdas' :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) @@ -74,6 +79,23 @@ liftLambdas' e = do newScope [] liftLambdas e +liftLetBinding :: (I.VarId, I.Expr Poly.Type) -> LiftFn (Maybe (I.VarId, I.Expr Poly.Type)) +liftLetBinding (v, lam@(I.Lambda _ _ t)) = do + let (vs, body) = I.collectLambda lam + oldCtx <- get + newScope $ map (\(Just vi) -> vi) vs + liftedLamBody <- liftLambdas body + lamFrees <- gets currentFrees + liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) liftedLamBody t + freshNum <- getFresh + addLifted (show v ++ "_lifted_" ++ show freshNum) liftedLam + modify $ \st -> st { currentScope = currentScope oldCtx, currentFrees = currentScope oldCtx } + return Nothing + +liftLetBinding (v, e) = do + liftedBody <- liftLambdas e + return $ Just (v, liftedBody) + liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do isNotFree <- inCurrentScope v @@ -93,24 +115,35 @@ liftLambdas lam@(I.Lambda _ _ t) = do newScope $ map (\(Just v) -> v) vs liftedLamBody <- liftLambdas body lamFrees <- gets currentFrees - liftedLam <- makeLiftedLambda (map (Just) (S.toList lamFrees) ++ vs) body t + liftedLam <- makeLiftedLambda (map (Just) (S.toList lamFrees) ++ vs) liftedLamBody t freshNum <- getFresh addLifted ("anon" ++ (show freshNum)) liftedLam + modify $ \st -> st { currentScope = currentScope oldCtx, currentFrees = currentScope oldCtx } return (foldl (\app v -> I.App app (I.Var v t) t) (I.Var (I.VarId (Identifier ("anon" ++ (show freshNum)))) t) (S.toList lamFrees)) - -liftProgramLambdas' :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) -liftProgramLambdas' p = runLiftFn $ do +liftLambdas lbs@(I.Let bs e t) = do + let vs = map (\(Just v, _) -> v) bs + exprs = map (\(_, e) -> e) bs + mapM addCurrentScope vs + liftedBindings <- mapM liftLetBinding (zip vs exprs) + liftedExpr <- liftLambdas e + return $ I.Let (((map (\(Just (v, e)) -> (Just v, e))) . (filter isJust)) liftedBindings) (liftedExpr) t +liftLambdas n = return n + +liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) +liftProgramLambdas p = runLiftFn $ do let defs = I.programDefs p funs = map snd $ filter isFun defs + funNames = map fst $ filter isFun defs + oths = filter (not . isFun) defs populateGlobalScope defs funsWithoutLambdas <- mapM liftLambdas funs - return p + liftedLambdas <- gets lifted + traceM "finished Lifting" + return $ p { I.programDefs = (liftedLambdas ++ (zip funNames funsWithoutLambdas) ++ oths) } where isFun (_, I.Lambda _ _ _) = True isFun _ = False - - - + {- liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) liftProgramLambdas p = do @@ -137,3 +170,4 @@ liftProgramLambdas p = do newScp = foldl (\s (Just v) -> S.insert v s) gs vs in trace (show v ++ ":\n" ++ (intercalate "\n" (getFrees newScp gs body) ++ "\n--")) [] getLetFrees scp gs (Just _, e) = getFrees scp gs e +-} From 06e6fbd0c9321751ce5773f88281839225b75a71 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Wed, 10 Nov 2021 02:19:12 -0500 Subject: [PATCH 06/40] Use a jumping off function for lambda lifting --- src/IR/LambdaLift.hs | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index c3ac8a5a..49ada28e 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -17,6 +17,7 @@ import Control.Monad.State.Lazy ( MonadState , modify ) +import Prettyprinter import qualified Data.Set as S import qualified Data.Map as M import Data.List (intercalate) @@ -74,14 +75,17 @@ makeLiftedLambda vs body t = do traceM (show vs) return (I.Lambda (head vs) liftedBody t) -liftLambdas' :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) -liftLambdas' e = do - newScope [] - liftLambdas e +liftLambdas' :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) +liftLambdas' (v, lam@(I.Lambda _ _ t)) = do + let (vs, body) = I.collectLambda lam + newScope $ map (\(Just vi) -> vi) vs + liftedBody <- liftLambdas body + return $ (v, foldl (\lam' v' -> (I.Lambda v' lam' t)) liftedBody vs) liftLetBinding :: (I.VarId, I.Expr Poly.Type) -> LiftFn (Maybe (I.VarId, I.Expr Poly.Type)) liftLetBinding (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam + traceM "lamda let" oldCtx <- get newScope $ map (\(Just vi) -> vi) vs liftedLamBody <- liftLambdas body @@ -91,8 +95,8 @@ liftLetBinding (v, lam@(I.Lambda _ _ t)) = do addLifted (show v ++ "_lifted_" ++ show freshNum) liftedLam modify $ \st -> st { currentScope = currentScope oldCtx, currentFrees = currentScope oldCtx } return Nothing - liftLetBinding (v, e) = do + traceM "non-lambda let" liftedBody <- liftLambdas e return $ Just (v, liftedBody) @@ -106,11 +110,13 @@ liftLambdas (I.App e1 e2 t) = do liftedE1 <- liftLambdas e1 liftedE2 <- liftLambdas e2 return $ I.App liftedE1 liftedE2 t -liftLambdas (I.Prim p exprs t) = do - liftedExprs <- mapM liftLambdas' exprs +liftLambdas a@(I.Prim p [l, r] t) = do + liftedExprs <- mapM liftLambdas [l, r] + traceM (show $ pretty l) -- what is happening? return $ I.Prim p liftedExprs t liftLambdas lam@(I.Lambda _ _ t) = do let (vs, body) = I.collectLambda lam + traceM "Lambda" oldCtx <- get newScope $ map (\(Just v) -> v) vs liftedLamBody <- liftLambdas body @@ -122,24 +128,25 @@ liftLambdas lam@(I.Lambda _ _ t) = do return (foldl (\app v -> I.App app (I.Var v t) t) (I.Var (I.VarId (Identifier ("anon" ++ (show freshNum)))) t) (S.toList lamFrees)) liftLambdas lbs@(I.Let bs e t) = do let vs = map (\(Just v, _) -> v) bs - exprs = map (\(_, e) -> e) bs - mapM addCurrentScope vs + exprs = map (\(_, e') -> e') bs + traceM "Let" + mapM_ addCurrentScope vs liftedBindings <- mapM liftLetBinding (zip vs exprs) liftedExpr <- liftLambdas e - return $ I.Let (((map (\(Just (v, e)) -> (Just v, e))) . (filter isJust)) liftedBindings) (liftedExpr) t + return $ I.Let (((map (\(Just (v, e')) -> (Just v, e'))) . (filter isJust)) liftedBindings) (liftedExpr) t liftLambdas n = return n liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) liftProgramLambdas p = runLiftFn $ do let defs = I.programDefs p - funs = map snd $ filter isFun defs + funs = filter isFun defs funNames = map fst $ filter isFun defs oths = filter (not . isFun) defs populateGlobalScope defs - funsWithoutLambdas <- mapM liftLambdas funs + funsWithLiftedBodies <- mapM liftLambdas' funs liftedLambdas <- gets lifted traceM "finished Lifting" - return $ p { I.programDefs = (liftedLambdas ++ (zip funNames funsWithoutLambdas) ++ oths) } + return $ p { I.programDefs = (oths ++ liftedLambdas ++ funsWithLiftedBodies) } where isFun (_, I.Lambda _ _ _) = True isFun _ = False From b0430a295bbefab3ea470c1c0333c2f54a3d2ab4 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 12 Nov 2021 02:55:37 -0500 Subject: [PATCH 07/40] Don't treat 'named' lambdas specially --- src/IR/LambdaLift.hs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 49ada28e..78031939 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -21,7 +21,7 @@ import Prettyprinter import qualified Data.Set as S import qualified Data.Map as M import Data.List (intercalate) -import Data.Maybe (isJust) +import Data.Maybe (isJust, fromJust) data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId , currentScope :: S.Set I.VarId @@ -82,6 +82,7 @@ liftLambdas' (v, lam@(I.Lambda _ _ t)) = do liftedBody <- liftLambdas body return $ (v, foldl (\lam' v' -> (I.Lambda v' lam' t)) liftedBody vs) +{- liftLetBinding :: (I.VarId, I.Expr Poly.Type) -> LiftFn (Maybe (I.VarId, I.Expr Poly.Type)) liftLetBinding (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam @@ -99,6 +100,7 @@ liftLetBinding (v, e) = do traceM "non-lambda let" liftedBody <- liftLambdas e return $ Just (v, liftedBody) +-} liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do @@ -127,13 +129,13 @@ liftLambdas lam@(I.Lambda _ _ t) = do modify $ \st -> st { currentScope = currentScope oldCtx, currentFrees = currentScope oldCtx } return (foldl (\app v -> I.App app (I.Var v t) t) (I.Var (I.VarId (Identifier ("anon" ++ (show freshNum)))) t) (S.toList lamFrees)) liftLambdas lbs@(I.Let bs e t) = do - let vs = map (\(Just v, _) -> v) bs + let vs = map (\(v, _) -> v) bs exprs = map (\(_, e') -> e') bs traceM "Let" - mapM_ addCurrentScope vs - liftedBindings <- mapM liftLetBinding (zip vs exprs) + mapM_ addCurrentScope (map fromJust vs) + liftedLetBodies <- mapM liftLambdas (exprs) liftedExpr <- liftLambdas e - return $ I.Let (((map (\(Just (v, e')) -> (Just v, e'))) . (filter isJust)) liftedBindings) (liftedExpr) t + return $ I.Let (zip vs liftedLetBodies) (liftedExpr) t liftLambdas n = return n liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) From 7d1ce3291ae0e346219923bf127417cb8861a40d Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 12 Nov 2021 03:02:01 -0500 Subject: [PATCH 08/40] Format code --- src/IR/LambdaLift.hs | 125 ++++++++++++++++++++++++++----------------- 1 file changed, 77 insertions(+), 48 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 78031939..85371f30 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -2,34 +2,37 @@ module IR.LambdaLift where import qualified Common.Compiler as Compiler -import Common.Identifiers +import Common.Identifiers import qualified IR.IR as I import qualified IR.Types.Poly as Poly -import Debug.Trace import Control.Monad.Except ( MonadError(..) ) import Control.Monad.State.Lazy ( MonadState , StateT(..) , evalStateT - , gets , get + , gets , modify ) +import Debug.Trace -import Prettyprinter -import qualified Data.Set as S -import qualified Data.Map as M -import Data.List (intercalate) -import Data.Maybe (isJust, fromJust) - -data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId - , currentScope :: S.Set I.VarId - , currentFrees :: S.Set I.VarId - , lifted :: [(I.VarId, I.Expr Poly.Type)] - , toAdjust :: M.Map I.VarId I.VarId - , anonCount :: Int - } +import Data.List ( intercalate ) +import qualified Data.Map as M +import Data.Maybe ( fromJust + , isJust + ) +import qualified Data.Set as S +import Prettyprinter + +data LiftCtx = LiftCtx + { globalScope :: S.Set I.VarId + , currentScope :: S.Set I.VarId + , currentFrees :: S.Set I.VarId + , lifted :: [(I.VarId, I.Expr Poly.Type)] + , toAdjust :: M.Map I.VarId I.VarId + , anonCount :: Int + } newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving Functor via (StateT LiftCtx Compiler.Pass) @@ -40,18 +43,27 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) runLiftFn :: LiftFn a -> Compiler.Pass a -runLiftFn (LiftFn m) = evalStateT m LiftCtx { globalScope = S.empty, currentScope = S.empty, currentFrees = S.empty, lifted = [], toAdjust = M.empty, anonCount = 0 } +runLiftFn (LiftFn m) = evalStateT + m + LiftCtx { globalScope = S.empty + , currentScope = S.empty + , currentFrees = S.empty + , lifted = [] + , toAdjust = M.empty + , anonCount = 0 + } populateGlobalScope :: [(I.VarId, I.Expr Poly.Type)] -> LiftFn () populateGlobalScope defs = do - let globalNames = map (\(v, _) -> v) defs + let globalNames = map fst defs modify $ \st -> st { globalScope = S.fromList globalNames } inCurrentScope :: I.VarId -> LiftFn Bool inCurrentScope v = S.member v <$> gets currentScope addCurrentScope :: I.VarId -> LiftFn () -addCurrentScope v = modify $ \st -> st { currentScope = S.insert v $ currentScope st } +addCurrentScope v = + modify $ \st -> st { currentScope = S.insert v $ currentScope st } getFresh :: LiftFn Int getFresh = do @@ -60,27 +72,34 @@ getFresh = do return curCount addLifted :: String -> I.Expr Poly.Type -> LiftFn () -addLifted name lam = modify $ \st -> st { lifted = ((I.VarId (Identifier name)), lam) : lifted st } +addLifted name lam = + modify $ \st -> st { lifted = (I.VarId (Identifier name), lam) : lifted st } addFreeVar :: I.VarId -> LiftFn () -addFreeVar v = modify $ \st -> st { currentFrees = S.insert v $ currentFrees st } +addFreeVar v = + modify $ \st -> st { currentFrees = S.insert v $ currentFrees st } newScope :: [I.VarId] -> LiftFn () -newScope vs = modify $ \st -> st { currentScope = S.union (globalScope st) (S.fromList vs), currentFrees = S.empty } +newScope vs = modify $ \st -> st + { currentScope = S.union (globalScope st) (S.fromList vs) + , currentFrees = S.empty + } -makeLiftedLambda :: [I.Binder] -> I.Expr Poly.Type -> Poly.Type -> LiftFn (I.Expr Poly.Type) +makeLiftedLambda + :: [I.Binder] -> I.Expr Poly.Type -> Poly.Type -> LiftFn (I.Expr Poly.Type) makeLiftedLambda [] body _ = return body makeLiftedLambda vs body t = do liftedBody <- makeLiftedLambda (tail vs) body t traceM (show vs) return (I.Lambda (head vs) liftedBody t) -liftLambdas' :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) +liftLambdas' + :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) liftLambdas' (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam newScope $ map (\(Just vi) -> vi) vs liftedBody <- liftLambdas body - return $ (v, foldl (\lam' v' -> (I.Lambda v' lam' t)) liftedBody vs) + return (v, foldl (\lam' v' -> I.Lambda v' lam' t) liftedBody vs) {- liftLetBinding :: (I.VarId, I.Expr Poly.Type) -> LiftFn (Maybe (I.VarId, I.Expr Poly.Type)) @@ -105,10 +124,12 @@ liftLetBinding (v, e) = do liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do isNotFree <- inCurrentScope v - if isNotFree then return n - else do addFreeVar v - return n -liftLambdas (I.App e1 e2 t) = do + if isNotFree + then return n + else do + addFreeVar v + return n +liftLambdas (I.App e1 e2 t) = do liftedE1 <- liftLambdas e1 liftedE2 <- liftLambdas e2 return $ I.App liftedE1 liftedE2 t @@ -122,36 +143,44 @@ liftLambdas lam@(I.Lambda _ _ t) = do oldCtx <- get newScope $ map (\(Just v) -> v) vs liftedLamBody <- liftLambdas body - lamFrees <- gets currentFrees - liftedLam <- makeLiftedLambda (map (Just) (S.toList lamFrees) ++ vs) liftedLamBody t + lamFrees <- gets currentFrees + liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) + liftedLamBody + t freshNum <- getFresh - addLifted ("anon" ++ (show freshNum)) liftedLam - modify $ \st -> st { currentScope = currentScope oldCtx, currentFrees = currentScope oldCtx } - return (foldl (\app v -> I.App app (I.Var v t) t) (I.Var (I.VarId (Identifier ("anon" ++ (show freshNum)))) t) (S.toList lamFrees)) + addLifted ("anon" ++ show freshNum) liftedLam + modify $ \st -> st { currentScope = currentScope oldCtx + , currentFrees = currentScope oldCtx + } + return + (foldl (\app v -> I.App app (I.Var v t) t) + (I.Var (I.VarId (Identifier ("anon" ++ show freshNum))) t) + (S.toList lamFrees) + ) liftLambdas lbs@(I.Let bs e t) = do - let vs = map (\(v, _) -> v) bs - exprs = map (\(_, e') -> e') bs + let vs = map fst bs + exprs = map snd bs traceM "Let" - mapM_ addCurrentScope (map fromJust vs) - liftedLetBodies <- mapM liftLambdas (exprs) - liftedExpr <- liftLambdas e - return $ I.Let (zip vs liftedLetBodies) (liftedExpr) t -liftLambdas n = return n + mapM_ (addCurrentScope . fromJust) vs + liftedLetBodies <- mapM liftLambdas exprs + liftedExpr <- liftLambdas e + return $ I.Let (zip vs liftedLetBodies) liftedExpr t +liftLambdas n = return n -liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) +liftProgramLambdas + :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) liftProgramLambdas p = runLiftFn $ do let defs = I.programDefs p funs = filter isFun defs - funNames = map fst $ filter isFun defs oths = filter (not . isFun) defs populateGlobalScope defs funsWithLiftedBodies <- mapM liftLambdas' funs - liftedLambdas <- gets lifted + liftedLambdas <- gets lifted traceM "finished Lifting" - return $ p { I.programDefs = (oths ++ liftedLambdas ++ funsWithLiftedBodies) } - where - isFun (_, I.Lambda _ _ _) = True - isFun _ = False + return $ p { I.programDefs = oths ++ liftedLambdas ++ funsWithLiftedBodies } + where + isFun (_, I.Lambda{}) = True + isFun _ = False {- liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) From 1e2494139381fe85a0cc250c719fe5d359483ae7 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 12 Nov 2021 15:49:50 -0500 Subject: [PATCH 09/40] Lift prims --- src/IR/LambdaLift.hs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 85371f30..a48260fe 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -25,6 +25,8 @@ import Data.Maybe ( fromJust import qualified Data.Set as S import Prettyprinter +import GHC.Stack + data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId , currentScope :: S.Set I.VarId @@ -121,7 +123,7 @@ liftLetBinding (v, e) = do return $ Just (v, liftedBody) -} -liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) +liftLambdas :: (HasCallStack) => I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do isNotFree <- inCurrentScope v if isNotFree @@ -133,9 +135,8 @@ liftLambdas (I.App e1 e2 t) = do liftedE1 <- liftLambdas e1 liftedE2 <- liftLambdas e2 return $ I.App liftedE1 liftedE2 t -liftLambdas a@(I.Prim p [l, r] t) = do - liftedExprs <- mapM liftLambdas [l, r] - traceM (show $ pretty l) -- what is happening? +liftLambdas a@(I.Prim p exprs t) = do + liftedExprs <- mapM liftLambdas exprs return $ I.Prim p liftedExprs t liftLambdas lam@(I.Lambda _ _ t) = do let (vs, body) = I.collectLambda lam From ddc46cb313ab26a6000ad5ad314747158cb90f76 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Mon, 15 Nov 2021 00:56:20 -0500 Subject: [PATCH 10/40] Remove unneeded code --- src/IR/LambdaLift.hs | 65 ++++---------------------------------------- 1 file changed, 5 insertions(+), 60 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index a48260fe..9ffa96ca 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -17,22 +17,15 @@ import Control.Monad.State.Lazy ( MonadState ) import Debug.Trace -import Data.List ( intercalate ) -import qualified Data.Map as M -import Data.Maybe ( fromJust - , isJust - ) +import Data.Maybe ( fromJust ) import qualified Data.Set as S -import Prettyprinter -import GHC.Stack data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId , currentScope :: S.Set I.VarId , currentFrees :: S.Set I.VarId , lifted :: [(I.VarId, I.Expr Poly.Type)] - , toAdjust :: M.Map I.VarId I.VarId , anonCount :: Int } @@ -51,7 +44,6 @@ runLiftFn (LiftFn m) = evalStateT , currentScope = S.empty , currentFrees = S.empty , lifted = [] - , toAdjust = M.empty , anonCount = 0 } @@ -102,28 +94,9 @@ liftLambdas' (v, lam@(I.Lambda _ _ t)) = do newScope $ map (\(Just vi) -> vi) vs liftedBody <- liftLambdas body return (v, foldl (\lam' v' -> I.Lambda v' lam' t) liftedBody vs) +liftLambdas' _ = error "Expected top-level lambda binding" -{- -liftLetBinding :: (I.VarId, I.Expr Poly.Type) -> LiftFn (Maybe (I.VarId, I.Expr Poly.Type)) -liftLetBinding (v, lam@(I.Lambda _ _ t)) = do - let (vs, body) = I.collectLambda lam - traceM "lamda let" - oldCtx <- get - newScope $ map (\(Just vi) -> vi) vs - liftedLamBody <- liftLambdas body - lamFrees <- gets currentFrees - liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) liftedLamBody t - freshNum <- getFresh - addLifted (show v ++ "_lifted_" ++ show freshNum) liftedLam - modify $ \st -> st { currentScope = currentScope oldCtx, currentFrees = currentScope oldCtx } - return Nothing -liftLetBinding (v, e) = do - traceM "non-lambda let" - liftedBody <- liftLambdas e - return $ Just (v, liftedBody) --} - -liftLambdas :: (HasCallStack) => I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) +liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do isNotFree <- inCurrentScope v if isNotFree @@ -135,7 +108,7 @@ liftLambdas (I.App e1 e2 t) = do liftedE1 <- liftLambdas e1 liftedE2 <- liftLambdas e2 return $ I.App liftedE1 liftedE2 t -liftLambdas a@(I.Prim p exprs t) = do +liftLambdas (I.Prim p exprs t) = do liftedExprs <- mapM liftLambdas exprs return $ I.Prim p liftedExprs t liftLambdas lam@(I.Lambda _ _ t) = do @@ -158,7 +131,7 @@ liftLambdas lam@(I.Lambda _ _ t) = do (I.Var (I.VarId (Identifier ("anon" ++ show freshNum))) t) (S.toList lamFrees) ) -liftLambdas lbs@(I.Let bs e t) = do +liftLambdas (I.Let bs e t) = do let vs = map fst bs exprs = map snd bs traceM "Let" @@ -182,31 +155,3 @@ liftProgramLambdas p = runLiftFn $ do where isFun (_, I.Lambda{}) = True isFun _ = False - {- -liftProgramLambdas - :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) -liftProgramLambdas p = do - let defs = I.programDefs p - globalScope = map (\(v, _) -> v) defs - funs = filter isFun defs - freeVars = map (getFrees (S.fromList globalScope) (S.fromList globalScope)) (map snd funs) - traceM (show $ zip (map (\(v, _) -> v) funs) freeVars) - return p - where - isFun (_, I.Lambda _ _ _) = True - isFun _ = False - getFrees scp gs (I.Var v _) = if S.member (v) scp then [] else [show v] - getFrees scp gs (I.App e1 e2 _) = getFrees scp gs e1 ++ getFrees scp gs e2 - getFrees scp gs (I.Let binds e _) = let newScp = foldl (\s (Just v, _) -> S.insert (v) s) scp binds in - (concatMap (getLetFrees newScp gs) binds) ++ getFrees newScp gs e - getFrees scp gs lam@(I.Lambda _ _ _) = let (vs, body) = I.collectLambda lam - newScp = foldl (\s (Just v) -> S.insert v s) gs vs in - trace (intercalate "\n" (getFrees newScp gs body) ++ "\n--") [] - getFrees scp gs (I.Match _ _ _ _) = undefined - getFrees scp gs (I.Prim _ es _) = concatMap (getFrees scp gs) es - getFrees _ _ _ = [] - getLetFrees scp gs (Just v, lam@(I.Lambda _ _ _ )) = let (vs, body) = I.collectLambda lam - newScp = foldl (\s (Just v) -> S.insert v s) gs vs in - trace (show v ++ ":\n" ++ (intercalate "\n" (getFrees newScp gs body) ++ "\n--")) [] - getLetFrees scp gs (Just _, e) = getFrees scp gs e --} From a5430f6ab86977b9e68886e29ff84763ce4e3854 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 16 Nov 2021 13:53:27 -0500 Subject: [PATCH 11/40] Tweak coding style --- src/IR/LambdaLift.hs | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 9ffa96ca..cd939743 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -14,10 +14,11 @@ import Control.Monad.State.Lazy ( MonadState , get , gets , modify + , when ) import Debug.Trace -import Data.Maybe ( fromJust ) +import Data.Maybe ( catMaybes ) import qualified Data.Set as S @@ -59,11 +60,11 @@ addCurrentScope :: I.VarId -> LiftFn () addCurrentScope v = modify $ \st -> st { currentScope = S.insert v $ currentScope st } -getFresh :: LiftFn Int +getFresh :: LiftFn String getFresh = do curCount <- gets anonCount modify $ \st -> st { anonCount = anonCount st + 1 } - return curCount + return $ "anon" ++ show curCount addLifted :: String -> I.Expr Poly.Type -> LiftFn () addLifted name lam = @@ -91,7 +92,7 @@ liftLambdas' :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) liftLambdas' (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam - newScope $ map (\(Just vi) -> vi) vs + newScope $ catMaybes vs liftedBody <- liftLambdas body return (v, foldl (\lam' v' -> I.Lambda v' lam' t) liftedBody vs) liftLambdas' _ = error "Expected top-level lambda binding" @@ -99,11 +100,8 @@ liftLambdas' _ = error "Expected top-level lambda binding" liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do isNotFree <- inCurrentScope v - if isNotFree - then return n - else do - addFreeVar v - return n + when isNotFree $ addFreeVar v + return n liftLambdas (I.App e1 e2 t) = do liftedE1 <- liftLambdas e1 liftedE2 <- liftLambdas e2 @@ -115,27 +113,26 @@ liftLambdas lam@(I.Lambda _ _ t) = do let (vs, body) = I.collectLambda lam traceM "Lambda" oldCtx <- get - newScope $ map (\(Just v) -> v) vs + newScope $ catMaybes vs liftedLamBody <- liftLambdas body lamFrees <- gets currentFrees liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) liftedLamBody t - freshNum <- getFresh - addLifted ("anon" ++ show freshNum) liftedLam + freshName <- getFresh + addLifted freshName liftedLam modify $ \st -> st { currentScope = currentScope oldCtx , currentFrees = currentScope oldCtx } - return - (foldl (\app v -> I.App app (I.Var v t) t) - (I.Var (I.VarId (Identifier ("anon" ++ show freshNum))) t) - (S.toList lamFrees) - ) + return $ foldl (\app v -> I.App app (I.Var v t) t) + (I.Var (I.VarId (Identifier freshName)) t) + (S.toList lamFrees) + liftLambdas (I.Let bs e t) = do let vs = map fst bs exprs = map snd bs traceM "Let" - mapM_ (addCurrentScope . fromJust) vs + mapM_ addCurrentScope (catMaybes vs) liftedLetBodies <- mapM liftLambdas exprs liftedExpr <- liftLambdas e return $ I.Let (zip vs liftedLetBodies) liftedExpr t From 82943153ded70efeee1dc85130e403e53fd0dd47 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 16 Nov 2021 17:29:23 -0500 Subject: [PATCH 12/40] Fix free var detection bug --- src/IR/LambdaLift.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index cd939743..e30080e1 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -14,7 +14,7 @@ import Control.Monad.State.Lazy ( MonadState , get , gets , modify - , when + , unless ) import Debug.Trace @@ -99,8 +99,8 @@ liftLambdas' _ = error "Expected top-level lambda binding" liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do - isNotFree <- inCurrentScope v - when isNotFree $ addFreeVar v + inScope <- inCurrentScope v + unless inScope $ addFreeVar v return n liftLambdas (I.App e1 e2 t) = do liftedE1 <- liftLambdas e1 From 1092761154e001f8fc970a79f017e7d08c319f13 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 16 Nov 2021 17:56:00 -0500 Subject: [PATCH 13/40] Exploit very epic laziness to descend into lambda bodies --- src/IR/LambdaLift.hs | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index e30080e1..ab4ef183 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -11,7 +11,6 @@ import Control.Monad.Except ( MonadError(..) ) import Control.Monad.State.Lazy ( MonadState , StateT(..) , evalStateT - , get , gets , modify , unless @@ -97,6 +96,16 @@ liftLambdas' (v, lam@(I.Lambda _ _ t)) = do return (v, foldl (\lam' v' -> I.Lambda v' lam' t) liftedBody vs) liftLambdas' _ = error "Expected top-level lambda binding" +descend :: LiftFn a -> LiftFn (a, S.Set VarId) +descend lb = do + savedScope <- gets currentScope + savedFrees <- gets currentFrees + liftedLamBody <- lb + lamFrees <- gets currentFrees + modify $ \st -> + st { currentScope = savedScope, currentFrees = savedFrees } + return (liftedLamBody, lamFrees) + liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v _) = do inScope <- inCurrentScope v @@ -112,18 +121,13 @@ liftLambdas (I.Prim p exprs t) = do liftLambdas lam@(I.Lambda _ _ t) = do let (vs, body) = I.collectLambda lam traceM "Lambda" - oldCtx <- get - newScope $ catMaybes vs - liftedLamBody <- liftLambdas body - lamFrees <- gets currentFrees - liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) - liftedLamBody - t + (liftedLamBody, lamFrees) <- + descend $ newScope (catMaybes vs) >> liftLambdas body + liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) + liftedLamBody + t freshName <- getFresh addLifted freshName liftedLam - modify $ \st -> st { currentScope = currentScope oldCtx - , currentFrees = currentScope oldCtx - } return $ foldl (\app v -> I.App app (I.Var v t) t) (I.Var (I.VarId (Identifier freshName)) t) (S.toList lamFrees) From 79802e41b564f91644a751e898495c237ed17fd8 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 16 Nov 2021 18:18:28 -0500 Subject: [PATCH 14/40] Add compiler option for dumping lifted IR --- app/Main.hs | 17 +++++++++++------ src/IR/LambdaLift.hs | 8 +------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/app/Main.hs b/app/Main.hs index 0b830f15..18b844cd 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -32,6 +32,7 @@ data Mode | DumpAST -- ^ AST before operator parsing | DumpASTP -- ^ AST after operators are parsed | DumpIR -- ^ Intermediate representation + | DumpIRLifted -- ^ Intermediate representation with lifted lambdas | GenerateC -- ^ Generate C backend deriving (Eq, Show) @@ -69,6 +70,10 @@ optionDescriptions = ["dump-ir"] (NoArg (\opt -> return opt { optMode = DumpIR })) "Print the IR" + , Option "" + ["dump-lifted-ir"] + (NoArg (\opt -> return opt { optMode = DumpIRLifted })) + "Print the IR with lifted lambdas" , Option "" ["generate-c"] (NoArg (\opt -> return opt { optMode = GenerateC })) @@ -113,7 +118,7 @@ main = do when (optMode opts == DumpAST) $ putStrLn (spaghetti ast) >> exitSuccess ast' <- doPass $ Front.desugarAst ast - when (optMode opts == DumpASTP) $ putStrLn (spaghetti ast') >> exitSuccess + when (optMode opts == DumpASTP) $ putStrLn (spaghetti ast') >> exitSuccess () <- doPass $ Front.checkAst ast' @@ -121,15 +126,15 @@ main = do when (optMode opts == DumpIR) $ putStrLn (spaghetti irA) >> exitSuccess - irC <- doPass $ IR.inferTypes irA + irC <- doPass $ IR.inferTypes irA - irP <- doPass $ IR.instantiateClasses irC + irP <- doPass $ IR.instantiateClasses irC - irY <- doPass $ IR.yieldAbstraction irP + irY <- doPass $ IR.yieldAbstraction irP - irL <- doPass $ IR.lambdaLift irY + irL <- doPass $ IR.lambdaLift irY - when (True) $ putStrLn (spaghetti irL) >> exitSuccess + when (optMode opts == DumpIRLifted) $ putStrLn (spaghetti irL) >> exitSuccess irI <- doPass $ IR.defunctionalize irL diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index ab4ef183..b5314280 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -15,7 +15,6 @@ import Control.Monad.State.Lazy ( MonadState , modify , unless ) -import Debug.Trace import Data.Maybe ( catMaybes ) import qualified Data.Set as S @@ -84,7 +83,6 @@ makeLiftedLambda makeLiftedLambda [] body _ = return body makeLiftedLambda vs body t = do liftedBody <- makeLiftedLambda (tail vs) body t - traceM (show vs) return (I.Lambda (head vs) liftedBody t) liftLambdas' @@ -102,8 +100,7 @@ descend lb = do savedFrees <- gets currentFrees liftedLamBody <- lb lamFrees <- gets currentFrees - modify $ \st -> - st { currentScope = savedScope, currentFrees = savedFrees } + modify $ \st -> st { currentScope = savedScope, currentFrees = savedFrees } return (liftedLamBody, lamFrees) liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) @@ -120,7 +117,6 @@ liftLambdas (I.Prim p exprs t) = do return $ I.Prim p liftedExprs t liftLambdas lam@(I.Lambda _ _ t) = do let (vs, body) = I.collectLambda lam - traceM "Lambda" (liftedLamBody, lamFrees) <- descend $ newScope (catMaybes vs) >> liftLambdas body liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) @@ -135,7 +131,6 @@ liftLambdas lam@(I.Lambda _ _ t) = do liftLambdas (I.Let bs e t) = do let vs = map fst bs exprs = map snd bs - traceM "Let" mapM_ addCurrentScope (catMaybes vs) liftedLetBodies <- mapM liftLambdas exprs liftedExpr <- liftLambdas e @@ -151,7 +146,6 @@ liftProgramLambdas p = runLiftFn $ do populateGlobalScope defs funsWithLiftedBodies <- mapM liftLambdas' funs liftedLambdas <- gets lifted - traceM "finished Lifting" return $ p { I.programDefs = oths ++ liftedLambdas ++ funsWithLiftedBodies } where isFun (_, I.Lambda{}) = True From 0da1a0b084f141955d70539787abbbe77756c4ee Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Thu, 25 Nov 2021 20:29:22 -0500 Subject: [PATCH 15/40] Reconstruct lambdas with correct types --- src/IR/LambdaLift.hs | 88 ++++++++++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 28 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index b5314280..de88feac 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -16,6 +16,8 @@ import Control.Monad.State.Lazy ( MonadState , unless ) +import qualified Data.Bifunctor as B +import qualified Data.Map as M import Data.Maybe ( catMaybes ) import qualified Data.Set as S @@ -23,7 +25,7 @@ import qualified Data.Set as S data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId , currentScope :: S.Set I.VarId - , currentFrees :: S.Set I.VarId + , freeTypes :: M.Map I.VarId Poly.Type , lifted :: [(I.VarId, I.Expr Poly.Type)] , anonCount :: Int } @@ -36,12 +38,26 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving (MonadError Compiler.Error) via (StateT LiftCtx Compiler.Pass) deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) +exprType :: I.Expr Poly.Type -> Poly.Type +exprType (I.Lambda _ _ t) = t +exprType (I.App _ _ t) = t +exprType (I.Lit _ t ) = t +exprType (I.Var _ t ) = t +exprType (I.Prim _ _ t ) = t +exprType (I.Let _ _ t ) = t + +zipArgsWithArrow :: [Binder] -> Poly.Type -> [(Binder, Poly.Type)] +zipArgsWithArrow (b : bs) (Poly.TBuiltin (Poly.Arrow t ts)) = + (b, t) : zipArgsWithArrow bs ts +zipArgsWithArrow [] _ = [] +zipArgsWithArrow _ _ = error "Expected longer arrow type" + runLiftFn :: LiftFn a -> Compiler.Pass a runLiftFn (LiftFn m) = evalStateT m LiftCtx { globalScope = S.empty , currentScope = S.empty - , currentFrees = S.empty + , freeTypes = M.empty , lifted = [] , anonCount = 0 } @@ -68,45 +84,54 @@ addLifted :: String -> I.Expr Poly.Type -> LiftFn () addLifted name lam = modify $ \st -> st { lifted = (I.VarId (Identifier name), lam) : lifted st } -addFreeVar :: I.VarId -> LiftFn () -addFreeVar v = - modify $ \st -> st { currentFrees = S.insert v $ currentFrees st } +addFreeVar :: I.VarId -> Poly.Type -> LiftFn () +addFreeVar v t = modify $ \st -> st { freeTypes = M.insert v t $ freeTypes st } newScope :: [I.VarId] -> LiftFn () newScope vs = modify $ \st -> st { currentScope = S.union (globalScope st) (S.fromList vs) - , currentFrees = S.empty + , freeTypes = M.empty } +makeArrow :: Poly.Type -> I.Expr Poly.Type -> Poly.Type +makeArrow lhsType rhsExpr = + Poly.TBuiltin $ Poly.Arrow lhsType (exprType rhsExpr) + makeLiftedLambda - :: [I.Binder] -> I.Expr Poly.Type -> Poly.Type -> LiftFn (I.Expr Poly.Type) -makeLiftedLambda [] body _ = return body -makeLiftedLambda vs body t = do - liftedBody <- makeLiftedLambda (tail vs) body t - return (I.Lambda (head vs) liftedBody t) + :: [(I.Binder, Poly.Type)] -> I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) +makeLiftedLambda [] body = return body +makeLiftedLambda ((v, t) : vs) body = do + liftedBody <- makeLiftedLambda vs body + return (I.Lambda v liftedBody $ makeArrow t liftedBody) liftLambdas' :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) liftLambdas' (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam + vs' = zipArgsWithArrow vs t newScope $ catMaybes vs liftedBody <- liftLambdas body - return (v, foldl (\lam' v' -> I.Lambda v' lam' t) liftedBody vs) + return + ( v + , foldl (\lam' (v', t') -> I.Lambda v' lam' $ makeArrow t' lam') + liftedBody + vs' + ) liftLambdas' _ = error "Expected top-level lambda binding" -descend :: LiftFn a -> LiftFn (a, S.Set VarId) +descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) descend lb = do - savedScope <- gets currentScope - savedFrees <- gets currentFrees - liftedLamBody <- lb - lamFrees <- gets currentFrees - modify $ \st -> st { currentScope = savedScope, currentFrees = savedFrees } - return (liftedLamBody, lamFrees) + savedScope <- gets currentScope + savedFreeTypes <- gets freeTypes + liftedLamBody <- lb + lamFreeTypes <- gets freeTypes + modify $ \st -> st { currentScope = savedScope, freeTypes = savedFreeTypes } + return (liftedLamBody, lamFreeTypes) liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) -liftLambdas n@(I.Var v _) = do +liftLambdas n@(I.Var v t) = do inScope <- inCurrentScope v - unless inScope $ addFreeVar v + unless inScope $ addFreeVar v t return n liftLambdas (I.App e1 e2 t) = do liftedE1 <- liftLambdas e1 @@ -117,16 +142,23 @@ liftLambdas (I.Prim p exprs t) = do return $ I.Prim p liftedExprs t liftLambdas lam@(I.Lambda _ _ t) = do let (vs, body) = I.collectLambda lam - (liftedLamBody, lamFrees) <- + vs' = zipArgsWithArrow vs t + (liftedLamBody, lamFreeTypes) <- descend $ newScope (catMaybes vs) >> liftLambdas body - liftedLam <- makeLiftedLambda (map Just (S.toList lamFrees) ++ vs) - liftedLamBody - t + liftedLam <- makeLiftedLambda + (map (B.first Just) (M.toList lamFreeTypes) ++ vs') + liftedLamBody freshName <- getFresh addLifted freshName liftedLam - return $ foldl (\app v -> I.App app (I.Var v t) t) - (I.Var (I.VarId (Identifier freshName)) t) - (S.toList lamFrees) + return $ applyFreesToLambda + (I.Var (I.VarId (Identifier freshName)) (exprType liftedLam)) + (M.toList lamFreeTypes) + (exprType liftedLam) + where + applyFreesToLambda app ((v', t') : vs) (Poly.TBuiltin (Poly.Arrow tl tr)) = + applyFreesToLambda (I.App app (I.Var v' t') tl) vs tr + applyFreesToLambda app [] _ = app + applyFreesToLambda _ _ _ = error "Expected longer arrow type" liftLambdas (I.Let bs e t) = do let vs = map fst bs From 6db3e97baea33ebb7a01cf20a27452e45390e8f1 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Thu, 25 Nov 2021 21:25:34 -0500 Subject: [PATCH 16/40] Fix lambda free-var application types --- src/IR/LambdaLift.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index de88feac..8434ffc5 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -155,8 +155,9 @@ liftLambdas lam@(I.Lambda _ _ t) = do (M.toList lamFreeTypes) (exprType liftedLam) where - applyFreesToLambda app ((v', t') : vs) (Poly.TBuiltin (Poly.Arrow tl tr)) = - applyFreesToLambda (I.App app (I.Var v' t') tl) vs tr + applyFreesToLambda app ((v', t') : vs) (Poly.TBuiltin (Poly.Arrow _ tr)) = + -- TODO(hans): We could assert t' == tl + applyFreesToLambda (I.App app (I.Var v' t') tr) vs tr applyFreesToLambda app [] _ = app applyFreesToLambda _ _ _ = error "Expected longer arrow type" From 1582b72a18576e16dacb69db496afb90ce13d7e1 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 00:54:01 -0500 Subject: [PATCH 17/40] Add lifting for match expressions --- src/IR/LambdaLift.hs | 45 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 8434ffc5..593e8722 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -17,6 +17,7 @@ import Control.Monad.State.Lazy ( MonadState ) import qualified Data.Bifunctor as B +import qualified Data.Foldable as F import qualified Data.Map as M import Data.Maybe ( catMaybes ) import qualified Data.Set as S @@ -39,12 +40,14 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) exprType :: I.Expr Poly.Type -> Poly.Type -exprType (I.Lambda _ _ t) = t -exprType (I.App _ _ t) = t -exprType (I.Lit _ t ) = t -exprType (I.Var _ t ) = t -exprType (I.Prim _ _ t ) = t -exprType (I.Let _ _ t ) = t +exprType (I.Lambda _ _ t ) = t +exprType (I.App _ _ t ) = t +exprType (I.Lit _ t ) = t +exprType (I.Var _ t ) = t +exprType (I.Prim _ _ t ) = t +exprType (I.Let _ _ t ) = t +exprType (I.Data _ t ) = t +exprType (I.Match _ _ _ t) = t zipArgsWithArrow :: [Binder] -> Poly.Type -> [(Binder, Poly.Type)] zipArgsWithArrow (b : bs) (Poly.TBuiltin (Poly.Arrow t ts)) = @@ -120,13 +123,31 @@ liftLambdas' (v, lam@(I.Lambda _ _ t)) = do liftLambdas' _ = error "Expected top-level lambda binding" descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) -descend lb = do +descend body = do savedScope <- gets currentScope savedFreeTypes <- gets freeTypes - liftedLamBody <- lb - lamFreeTypes <- gets freeTypes + liftedBody <- body + freeTypesBody <- gets freeTypes modify $ \st -> st { currentScope = savedScope, freeTypes = savedFreeTypes } - return (liftedLamBody, lamFreeTypes) + return (liftedBody, freeTypesBody) + +liftLambdasInArm + :: Binder -> (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type) +liftLambdasInArm sb (I.AltLit l, arm) = do + (liftedArm, armFrees) <- + descend $ F.forM_ sb addCurrentScope >> liftLambdas arm + modify $ \st -> st { freeTypes = armFrees } + return (I.AltLit l, liftedArm) +liftLambdasInArm sb (I.AltDefault, arm) = do + (liftedArm, armFrees) <- + descend $ F.forM_ sb addCurrentScope >> liftLambdas arm + modify $ \st -> st { freeTypes = armFrees } + return (I.AltDefault, liftedArm) +liftLambdasInArm sb (I.AltData d bs, arm) = do + (liftedArm, armFrees) <- + descend $ mapM_ addCurrentScope (catMaybes (sb : bs)) >> liftLambdas arm + modify $ \st -> st { freeTypes = armFrees } + return (I.AltData d bs, liftedArm) liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v t) = do @@ -168,6 +189,10 @@ liftLambdas (I.Let bs e t) = do liftedLetBodies <- mapM liftLambdas exprs liftedExpr <- liftLambdas e return $ I.Let (zip vs liftedLetBodies) liftedExpr t +liftLambdas (I.Match s sb arms t) = do + liftedMatch <- liftLambdas s + liftedArms <- mapM (liftLambdasInArm sb) arms + return $ I.Match liftedMatch sb liftedArms t liftLambdas n = return n liftProgramLambdas From ce5df9d2c96c36309f0493932991ffe6e9804aae Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 01:31:49 -0500 Subject: [PATCH 18/40] Rebase on main and update match code --- src/IR/LambdaLift.hs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 593e8722..d0178098 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -47,7 +47,7 @@ exprType (I.Var _ t ) = t exprType (I.Prim _ _ t ) = t exprType (I.Let _ _ t ) = t exprType (I.Data _ t ) = t -exprType (I.Match _ _ _ t) = t +exprType (I.Match _ _ t) = t zipArgsWithArrow :: [Binder] -> Poly.Type -> [(Binder, Poly.Type)] zipArgsWithArrow (b : bs) (Poly.TBuiltin (Poly.Arrow t ts)) = @@ -132,20 +132,20 @@ descend body = do return (liftedBody, freeTypesBody) liftLambdasInArm - :: Binder -> (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type) -liftLambdasInArm sb (I.AltLit l, arm) = do + :: (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type) +liftLambdasInArm (I.AltLit l, arm) = do (liftedArm, armFrees) <- - descend $ F.forM_ sb addCurrentScope >> liftLambdas arm + descend $ return () >> liftLambdas arm modify $ \st -> st { freeTypes = armFrees } return (I.AltLit l, liftedArm) -liftLambdasInArm sb (I.AltDefault, arm) = do +liftLambdasInArm (I.AltDefault b, arm) = do (liftedArm, armFrees) <- - descend $ F.forM_ sb addCurrentScope >> liftLambdas arm + descend $ F.forM_ b addCurrentScope >> liftLambdas arm modify $ \st -> st { freeTypes = armFrees } - return (I.AltDefault, liftedArm) -liftLambdasInArm sb (I.AltData d bs, arm) = do + return (I.AltDefault b, liftedArm) +liftLambdasInArm (I.AltData d bs, arm) = do (liftedArm, armFrees) <- - descend $ mapM_ addCurrentScope (catMaybes (sb : bs)) >> liftLambdas arm + descend $ mapM_ addCurrentScope (catMaybes bs) >> liftLambdas arm modify $ \st -> st { freeTypes = armFrees } return (I.AltData d bs, liftedArm) @@ -189,10 +189,10 @@ liftLambdas (I.Let bs e t) = do liftedLetBodies <- mapM liftLambdas exprs liftedExpr <- liftLambdas e return $ I.Let (zip vs liftedLetBodies) liftedExpr t -liftLambdas (I.Match s sb arms t) = do +liftLambdas (I.Match s arms t) = do liftedMatch <- liftLambdas s - liftedArms <- mapM (liftLambdasInArm sb) arms - return $ I.Match liftedMatch sb liftedArms t + liftedArms <- mapM liftLambdasInArm arms + return $ I.Match liftedMatch liftedArms t liftLambdas n = return n liftProgramLambdas From fda7e73531e0dddb67c2c0403123b03282795daa Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 02:04:19 -0500 Subject: [PATCH 19/40] Use makeLiftedLambda and insert some trace --- src/IR/LambdaLift.hs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index d0178098..bf5c413d 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -22,6 +22,8 @@ import qualified Data.Map as M import Data.Maybe ( catMaybes ) import qualified Data.Set as S +import Debug.Trace +import Prettyprinter data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId @@ -114,12 +116,9 @@ liftLambdas' (v, lam@(I.Lambda _ _ t)) = do vs' = zipArgsWithArrow vs t newScope $ catMaybes vs liftedBody <- liftLambdas body - return - ( v - , foldl (\lam' (v', t') -> I.Lambda v' lam' $ makeArrow t' lam') - liftedBody - vs' - ) + traceM (show v ++ "and body type " ++ show (pretty $exprType liftedBody)) + liftedLambda <- makeLiftedLambda vs' liftedBody + return (v, liftedLambda) liftLambdas' _ = error "Expected top-level lambda binding" descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) From 9c72fdf907b551de9544342d5f8d654b92675fd2 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 02:14:01 -0500 Subject: [PATCH 20/40] Use function type signature for top level lambda collection --- src/IR/LambdaLift.hs | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index bf5c413d..0915f7fa 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -22,9 +22,6 @@ import qualified Data.Map as M import Data.Maybe ( catMaybes ) import qualified Data.Set as S -import Debug.Trace -import Prettyprinter - data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId , currentScope :: S.Set I.VarId @@ -42,14 +39,14 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) exprType :: I.Expr Poly.Type -> Poly.Type -exprType (I.Lambda _ _ t ) = t -exprType (I.App _ _ t ) = t -exprType (I.Lit _ t ) = t -exprType (I.Var _ t ) = t -exprType (I.Prim _ _ t ) = t -exprType (I.Let _ _ t ) = t -exprType (I.Data _ t ) = t -exprType (I.Match _ _ t) = t +exprType (I.Lambda _ _ t) = t +exprType (I.App _ _ t) = t +exprType (I.Lit _ t ) = t +exprType (I.Var _ t ) = t +exprType (I.Prim _ _ t ) = t +exprType (I.Let _ _ t ) = t +exprType (I.Data _ t ) = t +exprType (I.Match _ _ t ) = t zipArgsWithArrow :: [Binder] -> Poly.Type -> [(Binder, Poly.Type)] zipArgsWithArrow (b : bs) (Poly.TBuiltin (Poly.Arrow t ts)) = @@ -111,14 +108,14 @@ makeLiftedLambda ((v, t) : vs) body = do liftLambdas' :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) -liftLambdas' (v, lam@(I.Lambda _ _ t)) = do +liftLambdas' (v, lam@I.Lambda{}) = do let (vs, body) = I.collectLambda lam - vs' = zipArgsWithArrow vs t newScope $ catMaybes vs liftedBody <- liftLambdas body - traceM (show v ++ "and body type " ++ show (pretty $exprType liftedBody)) - liftedLambda <- makeLiftedLambda vs' liftedBody - return (v, liftedLambda) + return (v, updateBody lam liftedBody) + where + updateBody (I.Lambda v' b' t') b = I.Lambda v' (updateBody b' b) t' + updateBody _ b = b liftLambdas' _ = error "Expected top-level lambda binding" descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) @@ -133,8 +130,7 @@ descend body = do liftLambdasInArm :: (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type) liftLambdasInArm (I.AltLit l, arm) = do - (liftedArm, armFrees) <- - descend $ return () >> liftLambdas arm + (liftedArm, armFrees) <- descend $ liftLambdas arm modify $ \st -> st { freeTypes = armFrees } return (I.AltLit l, liftedArm) liftLambdasInArm (I.AltDefault b, arm) = do From f1e86511c7505028e50ff26de692023120bd8487 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 15:06:43 -0500 Subject: [PATCH 21/40] Use makeLiftedLambda in liftLambdas' after regression test fix in main --- src/IR/LambdaLift.hs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 0915f7fa..2da8ebad 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -108,14 +108,13 @@ makeLiftedLambda ((v, t) : vs) body = do liftLambdas' :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) -liftLambdas' (v, lam@I.Lambda{}) = do +liftLambdas' (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam + vs' = zipArgsWithArrow vs t newScope $ catMaybes vs - liftedBody <- liftLambdas body - return (v, updateBody lam liftedBody) - where - updateBody (I.Lambda v' b' t') b = I.Lambda v' (updateBody b' b) t' - updateBody _ b = b + liftedBody <- liftLambdas body + liftedLambda <- makeLiftedLambda vs' liftedBody + return (v, liftedLambda) liftLambdas' _ = error "Expected top-level lambda binding" descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) From 1cf21e5de6e2a2ff0d0158d800fd2e46c2ff80c7 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 15:18:46 -0500 Subject: [PATCH 22/40] Use extract --- src/IR/LambdaLift.hs | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 2da8ebad..b8f1bbbb 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -7,6 +7,7 @@ import qualified IR.IR as I import qualified IR.Types.Poly as Poly +import Control.Comonad ( Comonad(..) ) import Control.Monad.Except ( MonadError(..) ) import Control.Monad.State.Lazy ( MonadState , StateT(..) @@ -38,16 +39,6 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving (MonadError Compiler.Error) via (StateT LiftCtx Compiler.Pass) deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) -exprType :: I.Expr Poly.Type -> Poly.Type -exprType (I.Lambda _ _ t) = t -exprType (I.App _ _ t) = t -exprType (I.Lit _ t ) = t -exprType (I.Var _ t ) = t -exprType (I.Prim _ _ t ) = t -exprType (I.Let _ _ t ) = t -exprType (I.Data _ t ) = t -exprType (I.Match _ _ t ) = t - zipArgsWithArrow :: [Binder] -> Poly.Type -> [(Binder, Poly.Type)] zipArgsWithArrow (b : bs) (Poly.TBuiltin (Poly.Arrow t ts)) = (b, t) : zipArgsWithArrow bs ts @@ -97,7 +88,7 @@ newScope vs = modify $ \st -> st makeArrow :: Poly.Type -> I.Expr Poly.Type -> Poly.Type makeArrow lhsType rhsExpr = - Poly.TBuiltin $ Poly.Arrow lhsType (exprType rhsExpr) + Poly.TBuiltin $ Poly.Arrow lhsType (extract rhsExpr) makeLiftedLambda :: [(I.Binder, Poly.Type)] -> I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) @@ -166,9 +157,9 @@ liftLambdas lam@(I.Lambda _ _ t) = do freshName <- getFresh addLifted freshName liftedLam return $ applyFreesToLambda - (I.Var (I.VarId (Identifier freshName)) (exprType liftedLam)) + (I.Var (I.VarId (Identifier freshName)) (extract liftedLam)) (M.toList lamFreeTypes) - (exprType liftedLam) + (extract liftedLam) where applyFreesToLambda app ((v', t') : vs) (Poly.TBuiltin (Poly.Arrow _ tr)) = -- TODO(hans): We could assert t' == tl From c25a20b10be341cf24b616247e9e261685366d9d Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 16:03:16 -0500 Subject: [PATCH 23/40] Rewrite applyFreesToLambda as an epic fold --- src/IR/LambdaLift.hs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index b8f1bbbb..ac7594f8 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -6,6 +6,7 @@ import Common.Identifiers import qualified IR.IR as I import qualified IR.Types.Poly as Poly +import IR.Types.TypeSystem ( dearrow ) import Control.Comonad ( Comonad(..) ) import Control.Monad.Except ( MonadError(..) ) @@ -156,16 +157,13 @@ liftLambdas lam@(I.Lambda _ _ t) = do liftedLamBody freshName <- getFresh addLifted freshName liftedLam - return $ applyFreesToLambda - (I.Var (I.VarId (Identifier freshName)) (extract liftedLam)) - (M.toList lamFreeTypes) - (extract liftedLam) + return $ foldl applyFree + (I.Var (I.VarId (Identifier freshName)) (extract liftedLam)) + (M.toList lamFreeTypes) where - applyFreesToLambda app ((v', t') : vs) (Poly.TBuiltin (Poly.Arrow _ tr)) = - -- TODO(hans): We could assert t' == tl - applyFreesToLambda (I.App app (I.Var v' t') tr) vs tr - applyFreesToLambda app [] _ = app - applyFreesToLambda _ _ _ = error "Expected longer arrow type" + applyFree app (v', t') = case dearrow $ extract app of + Just (_, tr) -> I.App app (I.Var v' t') tr + Nothing -> app liftLambdas (I.Let bs e t) = do let vs = map fst bs From 8516e5d518de34173670bf9d63352c910e4e2563 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Fri, 26 Nov 2021 16:09:23 -0500 Subject: [PATCH 24/40] Don't use catch-all in liftLambdas --- src/IR/LambdaLift.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index ac7594f8..dcb1062f 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -176,7 +176,8 @@ liftLambdas (I.Match s arms t) = do liftedMatch <- liftLambdas s liftedArms <- mapM liftLambdasInArm arms return $ I.Match liftedMatch liftedArms t -liftLambdas n = return n +liftLambdas lit@I.Lit{} = return lit +liftLambdas dat@I.Data{} = return dat liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) From c1fdfaad60b4627556173534f036d7258f3e42af Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sat, 27 Nov 2021 20:59:28 -0500 Subject: [PATCH 25/40] Rename liftLambdas' to liftLambdasTop --- src/IR/LambdaLift.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index dcb1062f..3691d25b 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -98,16 +98,16 @@ makeLiftedLambda ((v, t) : vs) body = do liftedBody <- makeLiftedLambda vs body return (I.Lambda v liftedBody $ makeArrow t liftedBody) -liftLambdas' +liftLambdasTop :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) -liftLambdas' (v, lam@(I.Lambda _ _ t)) = do +liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam vs' = zipArgsWithArrow vs t newScope $ catMaybes vs liftedBody <- liftLambdas body liftedLambda <- makeLiftedLambda vs' liftedBody return (v, liftedLambda) -liftLambdas' _ = error "Expected top-level lambda binding" +liftLambdasTop _ = error "Expected top-level lambda binding" descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) descend body = do @@ -186,7 +186,7 @@ liftProgramLambdas p = runLiftFn $ do funs = filter isFun defs oths = filter (not . isFun) defs populateGlobalScope defs - funsWithLiftedBodies <- mapM liftLambdas' funs + funsWithLiftedBodies <- mapM liftLambdasTop funs liftedLambdas <- gets lifted return $ p { I.programDefs = oths ++ liftedLambdas ++ funsWithLiftedBodies } where From 6031aa1d3a20e2395dec00397ab79c002ba89a65 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sat, 27 Nov 2021 21:07:09 -0500 Subject: [PATCH 26/40] Use 'arrow' instead of redefining it --- src/IR/LambdaLift.hs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 3691d25b..d1b83701 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -6,7 +6,9 @@ import Common.Identifiers import qualified IR.IR as I import qualified IR.Types.Poly as Poly -import IR.Types.TypeSystem ( dearrow ) +import IR.Types.TypeSystem ( arrow + , dearrow + ) import Control.Comonad ( Comonad(..) ) import Control.Monad.Except ( MonadError(..) ) @@ -87,16 +89,12 @@ newScope vs = modify $ \st -> st , freeTypes = M.empty } -makeArrow :: Poly.Type -> I.Expr Poly.Type -> Poly.Type -makeArrow lhsType rhsExpr = - Poly.TBuiltin $ Poly.Arrow lhsType (extract rhsExpr) - makeLiftedLambda :: [(I.Binder, Poly.Type)] -> I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) makeLiftedLambda [] body = return body makeLiftedLambda ((v, t) : vs) body = do liftedBody <- makeLiftedLambda vs body - return (I.Lambda v liftedBody $ makeArrow t liftedBody) + return (I.Lambda v liftedBody $ arrow t (extract liftedBody)) liftLambdasTop :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) From c1d8e4df06dcf7bc171b7fcfe900c25522a69909 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sat, 27 Nov 2021 21:19:45 -0500 Subject: [PATCH 27/40] Use collectArrow --- src/IR/LambdaLift.hs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index d1b83701..7c92829c 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -7,6 +7,7 @@ import qualified IR.IR as I import qualified IR.Types.Poly as Poly import IR.Types.TypeSystem ( arrow + , collectArrow , dearrow ) @@ -42,12 +43,6 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving (MonadError Compiler.Error) via (StateT LiftCtx Compiler.Pass) deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) -zipArgsWithArrow :: [Binder] -> Poly.Type -> [(Binder, Poly.Type)] -zipArgsWithArrow (b : bs) (Poly.TBuiltin (Poly.Arrow t ts)) = - (b, t) : zipArgsWithArrow bs ts -zipArgsWithArrow [] _ = [] -zipArgsWithArrow _ _ = error "Expected longer arrow type" - runLiftFn :: LiftFn a -> Compiler.Pass a runLiftFn (LiftFn m) = evalStateT m @@ -100,7 +95,7 @@ liftLambdasTop :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam - vs' = zipArgsWithArrow vs t + vs' = zip vs $ fst (collectArrow t) newScope $ catMaybes vs liftedBody <- liftLambdas body liftedLambda <- makeLiftedLambda vs' liftedBody @@ -147,7 +142,7 @@ liftLambdas (I.Prim p exprs t) = do return $ I.Prim p liftedExprs t liftLambdas lam@(I.Lambda _ _ t) = do let (vs, body) = I.collectLambda lam - vs' = zipArgsWithArrow vs t + vs' = zip vs $ fst (collectArrow t) (liftedLamBody, lamFreeTypes) <- descend $ newScope (catMaybes vs) >> liftLambdas body liftedLam <- makeLiftedLambda From ac9b9fe89a77df413db7fdd564743e52cefcc99e Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sat, 27 Nov 2021 21:34:15 -0500 Subject: [PATCH 28/40] Reorder definitions --- src/IR/LambdaLift.hs | 91 ++++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 7c92829c..76899e21 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -84,12 +84,28 @@ newScope vs = modify $ \st -> st , freeTypes = M.empty } -makeLiftedLambda - :: [(I.Binder, Poly.Type)] -> I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) -makeLiftedLambda [] body = return body -makeLiftedLambda ((v, t) : vs) body = do - liftedBody <- makeLiftedLambda vs body - return (I.Lambda v liftedBody $ arrow t (extract liftedBody)) +descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) +descend body = do + savedScope <- gets currentScope + savedFreeTypes <- gets freeTypes + liftedBody <- body + freeTypesBody <- gets freeTypes + modify $ \st -> st { currentScope = savedScope, freeTypes = savedFreeTypes } + return (liftedBody, freeTypesBody) + +liftProgramLambdas + :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) +liftProgramLambdas p = runLiftFn $ do + let defs = I.programDefs p + funs = filter isFun defs + oths = filter (not . isFun) defs + populateGlobalScope defs + funsWithLiftedBodies <- mapM liftLambdasTop funs + liftedLambdas <- gets lifted + return $ p { I.programDefs = oths ++ liftedLambdas ++ funsWithLiftedBodies } + where + isFun (_, I.Lambda{}) = True + isFun _ = False liftLambdasTop :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) @@ -102,32 +118,6 @@ liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do return (v, liftedLambda) liftLambdasTop _ = error "Expected top-level lambda binding" -descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) -descend body = do - savedScope <- gets currentScope - savedFreeTypes <- gets freeTypes - liftedBody <- body - freeTypesBody <- gets freeTypes - modify $ \st -> st { currentScope = savedScope, freeTypes = savedFreeTypes } - return (liftedBody, freeTypesBody) - -liftLambdasInArm - :: (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type) -liftLambdasInArm (I.AltLit l, arm) = do - (liftedArm, armFrees) <- descend $ liftLambdas arm - modify $ \st -> st { freeTypes = armFrees } - return (I.AltLit l, liftedArm) -liftLambdasInArm (I.AltDefault b, arm) = do - (liftedArm, armFrees) <- - descend $ F.forM_ b addCurrentScope >> liftLambdas arm - modify $ \st -> st { freeTypes = armFrees } - return (I.AltDefault b, liftedArm) -liftLambdasInArm (I.AltData d bs, arm) = do - (liftedArm, armFrees) <- - descend $ mapM_ addCurrentScope (catMaybes bs) >> liftLambdas arm - modify $ \st -> st { freeTypes = armFrees } - return (I.AltData d bs, liftedArm) - liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v t) = do inScope <- inCurrentScope v @@ -157,7 +147,6 @@ liftLambdas lam@(I.Lambda _ _ t) = do applyFree app (v', t') = case dearrow $ extract app of Just (_, tr) -> I.App app (I.Var v' t') tr Nothing -> app - liftLambdas (I.Let bs e t) = do let vs = map fst bs exprs = map snd bs @@ -172,16 +161,26 @@ liftLambdas (I.Match s arms t) = do liftLambdas lit@I.Lit{} = return lit liftLambdas dat@I.Data{} = return dat -liftProgramLambdas - :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) -liftProgramLambdas p = runLiftFn $ do - let defs = I.programDefs p - funs = filter isFun defs - oths = filter (not . isFun) defs - populateGlobalScope defs - funsWithLiftedBodies <- mapM liftLambdasTop funs - liftedLambdas <- gets lifted - return $ p { I.programDefs = oths ++ liftedLambdas ++ funsWithLiftedBodies } - where - isFun (_, I.Lambda{}) = True - isFun _ = False +liftLambdasInArm + :: (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type) +liftLambdasInArm (I.AltLit l, arm) = do + (liftedArm, armFrees) <- descend $ liftLambdas arm + modify $ \st -> st { freeTypes = armFrees } + return (I.AltLit l, liftedArm) +liftLambdasInArm (I.AltDefault b, arm) = do + (liftedArm, armFrees) <- + descend $ F.forM_ b addCurrentScope >> liftLambdas arm + modify $ \st -> st { freeTypes = armFrees } + return (I.AltDefault b, liftedArm) +liftLambdasInArm (I.AltData d bs, arm) = do + (liftedArm, armFrees) <- + descend $ mapM_ addCurrentScope (catMaybes bs) >> liftLambdas arm + modify $ \st -> st { freeTypes = armFrees } + return (I.AltData d bs, liftedArm) + +makeLiftedLambda + :: [(I.Binder, Poly.Type)] -> I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) +makeLiftedLambda [] body = return body +makeLiftedLambda ((v, t) : vs) body = do + liftedBody <- makeLiftedLambda vs body + return (I.Lambda v liftedBody $ arrow t (extract liftedBody)) From 9f089a011f8946e13890a7cd74bdd108e7f3f6ab Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sat, 27 Nov 2021 21:49:28 -0500 Subject: [PATCH 29/40] Use unzip instead of map fst --- src/IR/LambdaLift.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 76899e21..a24776e9 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -55,7 +55,7 @@ runLiftFn (LiftFn m) = evalStateT populateGlobalScope :: [(I.VarId, I.Expr Poly.Type)] -> LiftFn () populateGlobalScope defs = do - let globalNames = map fst defs + let (globalNames, _) = unzip defs modify $ \st -> st { globalScope = S.fromList globalNames } inCurrentScope :: I.VarId -> LiftFn Bool @@ -148,8 +148,7 @@ liftLambdas lam@(I.Lambda _ _ t) = do Just (_, tr) -> I.App app (I.Var v' t') tr Nothing -> app liftLambdas (I.Let bs e t) = do - let vs = map fst bs - exprs = map snd bs + let (vs, exprs) = unzip bs mapM_ addCurrentScope (catMaybes vs) liftedLetBodies <- mapM liftLambdas exprs liftedExpr <- liftLambdas e From 7859446fa07c39bb0e8abf6505dffccb71f0a4bb Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sat, 27 Nov 2021 22:11:04 -0500 Subject: [PATCH 30/40] Use IR makeChainedLambda export --- src/IR/IR.hs | 13 ++++++++++++- src/IR/LambdaLift.hs | 20 ++++++-------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/IR/IR.hs b/src/IR/IR.hs index fcb462af..6ae7455e 100644 --- a/src/IR/IR.hs +++ b/src/IR/IR.hs @@ -11,6 +11,7 @@ module IR.IR , DConId(..) , wellFormed , collectLambda + , makeLambdaChain ) where import Common.Identifiers ( Binder , DConId(..) @@ -20,7 +21,10 @@ import Common.Identifiers ( Binder import Control.Comonad ( Comonad(..) ) import Data.Bifunctor ( Bifunctor(..) ) -import IR.Types.TypeSystem ( TypeDef(..) ) +import IR.Types.TypeSystem ( TypeDef(..) + , TypeSystem + , arrow + ) import Common.Pretty @@ -186,6 +190,13 @@ collectLambda (Lambda a b _) = let (as, body) = collectLambda b in (a : as, body) collectLambda e = ([], e) +-- | Create a lambda chain given a list of argument-type pairs and a body. +makeLambdaChain :: TypeSystem t => [(Binder, t)] -> Expr t -> Expr t +makeLambdaChain [] body = body +makeLambdaChain ((v, ty) : vs) body = + let liftedBody = makeLambdaChain vs body + in Lambda v liftedBody $ arrow ty (extract liftedBody) + instance Functor Program where fmap f Program { programEntry = e, programDefs = defs, typeDefs = tds } = Program { programEntry = e diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index a24776e9..6ec3611f 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -6,8 +6,7 @@ import Common.Identifiers import qualified IR.IR as I import qualified IR.Types.Poly as Poly -import IR.Types.TypeSystem ( arrow - , collectArrow +import IR.Types.TypeSystem ( collectArrow , dearrow ) @@ -113,8 +112,8 @@ liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam vs' = zip vs $ fst (collectArrow t) newScope $ catMaybes vs - liftedBody <- liftLambdas body - liftedLambda <- makeLiftedLambda vs' liftedBody + liftedBody <- liftLambdas body + let liftedLambda = I.makeLambdaChain vs' liftedBody return (v, liftedLambda) liftLambdasTop _ = error "Expected top-level lambda binding" @@ -135,9 +134,9 @@ liftLambdas lam@(I.Lambda _ _ t) = do vs' = zip vs $ fst (collectArrow t) (liftedLamBody, lamFreeTypes) <- descend $ newScope (catMaybes vs) >> liftLambdas body - liftedLam <- makeLiftedLambda - (map (B.first Just) (M.toList lamFreeTypes) ++ vs') - liftedLamBody + let liftedLam = I.makeLambdaChain + (map (B.first Just) (M.toList lamFreeTypes) ++ vs') + liftedLamBody freshName <- getFresh addLifted freshName liftedLam return $ foldl applyFree @@ -176,10 +175,3 @@ liftLambdasInArm (I.AltData d bs, arm) = do descend $ mapM_ addCurrentScope (catMaybes bs) >> liftLambdas arm modify $ \st -> st { freeTypes = armFrees } return (I.AltData d bs, liftedArm) - -makeLiftedLambda - :: [(I.Binder, Poly.Type)] -> I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) -makeLiftedLambda [] body = return body -makeLiftedLambda ((v, t) : vs) body = do - liftedBody <- makeLiftedLambda vs body - return (I.Lambda v liftedBody $ arrow t (extract liftedBody)) From 4c8321f27d7150b017b11bb7493c78b6f2ae201c Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sun, 28 Nov 2021 15:52:16 -0500 Subject: [PATCH 31/40] Preserve relative top def order after lifting --- src/IR/LambdaLift.hs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 6ec3611f..1530caf5 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -92,30 +92,31 @@ descend body = do modify $ \st -> st { currentScope = savedScope, freeTypes = savedFreeTypes } return (liftedBody, freeTypesBody) +extractLifted :: LiftFn a -> LiftFn (a, [(I.VarId, I.Expr Poly.Type)]) +extractLifted body = do + liftedBody <- body + newTopDefs <- gets lifted + modify $ \st -> st { lifted = [] } + return (liftedBody, newTopDefs) + liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) liftProgramLambdas p = runLiftFn $ do let defs = I.programDefs p - funs = filter isFun defs - oths = filter (not . isFun) defs populateGlobalScope defs - funsWithLiftedBodies <- mapM liftLambdasTop funs - liftedLambdas <- gets lifted - return $ p { I.programDefs = oths ++ liftedLambdas ++ funsWithLiftedBodies } - where - isFun (_, I.Lambda{}) = True - isFun _ = False + liftedProgramDefs <- mapM liftLambdasTop defs + return $ p { I.programDefs = concat liftedProgramDefs } liftLambdasTop - :: (I.VarId, I.Expr Poly.Type) -> LiftFn (I.VarId, I.Expr Poly.Type) + :: (I.VarId, I.Expr Poly.Type) -> LiftFn [(I.VarId, I.Expr Poly.Type)] liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do let (vs, body) = I.collectLambda lam vs' = zip vs $ fst (collectArrow t) newScope $ catMaybes vs - liftedBody <- liftLambdas body + (liftedBody, newTopDefs) <- extractLifted $ liftLambdas body let liftedLambda = I.makeLambdaChain vs' liftedBody - return (v, liftedLambda) -liftLambdasTop _ = error "Expected top-level lambda binding" + return $ newTopDefs ++ [(v, liftedLambda)] +liftLambdasTop topDef = return [topDef] liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v t) = do From d6578b9f91ef7a7c3cb2cbdb9689bc865df2d8bc Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sun, 28 Nov 2021 16:59:47 -0500 Subject: [PATCH 32/40] Write some very epic haddock comments for LambdaLift module --- src/IR/LambdaLift.hs | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 1530caf5..1eabee04 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DerivingVia #-} -module IR.LambdaLift where +module IR.LambdaLift + ( liftProgramLambdas + ) where import qualified Common.Compiler as Compiler import Common.Identifiers @@ -26,6 +28,7 @@ import qualified Data.Map as M import Data.Maybe ( catMaybes ) import qualified Data.Set as S +-- | Lifting Environment data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId , currentScope :: S.Set I.VarId @@ -34,6 +37,7 @@ data LiftCtx = LiftCtx , anonCount :: Int } +-- Lift Monad newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving Functor via (StateT LiftCtx Compiler.Pass) deriving Applicative via (StateT LiftCtx Compiler.Pass) @@ -42,6 +46,7 @@ newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a) deriving (MonadError Compiler.Error) via (StateT LiftCtx Compiler.Pass) deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass) +-- | Run a LiftFn computation. runLiftFn :: LiftFn a -> Compiler.Pass a runLiftFn (LiftFn m) = evalStateT m @@ -52,37 +57,45 @@ runLiftFn (LiftFn m) = evalStateT , anonCount = 0 } +-- | Extract top level definition names that compose program's global scope. populateGlobalScope :: [(I.VarId, I.Expr Poly.Type)] -> LiftFn () populateGlobalScope defs = do let (globalNames, _) = unzip defs modify $ \st -> st { globalScope = S.fromList globalNames } +-- | Check if an identifier is in the current program scope. inCurrentScope :: I.VarId -> LiftFn Bool inCurrentScope v = S.member v <$> gets currentScope +-- | Add an identifier to the current program scope. addCurrentScope :: I.VarId -> LiftFn () addCurrentScope v = modify $ \st -> st { currentScope = S.insert v $ currentScope st } +-- | Get a fresh variable name getFresh :: LiftFn String getFresh = do curCount <- gets anonCount modify $ \st -> st { anonCount = anonCount st + 1 } return $ "anon" ++ show curCount +-- | Store a new lifted lambda to later add to the program's top level definitions. addLifted :: String -> I.Expr Poly.Type -> LiftFn () addLifted name lam = modify $ \st -> st { lifted = (I.VarId (Identifier name), lam) : lifted st } +-- | Register a (free variable, type) mapping for the current program scope. addFreeVar :: I.VarId -> Poly.Type -> LiftFn () addFreeVar v t = modify $ \st -> st { freeTypes = M.insert v t $ freeTypes st } +-- | Update lift environment before entering a new scope (e.g. lambda body, match arm). newScope :: [I.VarId] -> LiftFn () newScope vs = modify $ \st -> st { currentScope = S.union (globalScope st) (S.fromList vs) , freeTypes = M.empty } +-- | Context management for lifting new scopes (restore information after lifting the body). descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) descend body = do savedScope <- gets currentScope @@ -92,6 +105,7 @@ descend body = do modify $ \st -> st { currentScope = savedScope, freeTypes = savedFreeTypes } return (liftedBody, freeTypesBody) +-- | Context management for lifting top level lambda definitions. extractLifted :: LiftFn a -> LiftFn (a, [(I.VarId, I.Expr Poly.Type)]) extractLifted body = do liftedBody <- body @@ -99,6 +113,11 @@ extractLifted body = do modify $ \st -> st { lifted = [] } return (liftedBody, newTopDefs) +{- | Entry-point to lambda lifting. + +Maps over top level definitions and lifts out lambda definitions to create a new +Program with the relative order of user definitions preserved. +-} liftProgramLambdas :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type) liftProgramLambdas p = runLiftFn $ do @@ -107,6 +126,7 @@ liftProgramLambdas p = runLiftFn $ do liftedProgramDefs <- mapM liftLambdasTop defs return $ p { I.programDefs = concat liftedProgramDefs } +-- | Given a top-level definition, lift out any lambda definitions. liftLambdasTop :: (I.VarId, I.Expr Poly.Type) -> LiftFn [(I.VarId, I.Expr Poly.Type)] liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do @@ -118,6 +138,14 @@ liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do return $ newTopDefs ++ [(v, liftedLambda)] liftLambdasTop topDef = return [topDef] +{- | Lifting logic for IR expressions. + +As we traverse over IR expressions, we note down any bindings we encounter so +that we can detect free variables. For lambda definitions, we use free +variables to create a new top-level lifted equivalent and then adjust the +callsite by partially-applying the new lifted lambda with those free variables +from the surrounding the scope. +-} liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type) liftLambdas n@(I.Var v t) = do inScope <- inCurrentScope v @@ -160,6 +188,7 @@ liftLambdas (I.Match s arms t) = do liftLambdas lit@I.Lit{} = return lit liftLambdas dat@I.Data{} = return dat +-- | Entry point for traversing the arms of match expressions. liftLambdasInArm :: (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type) liftLambdasInArm (I.AltLit l, arm) = do From efef4e8314d6055a5c3307131555c890c24530bd Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sun, 28 Nov 2021 19:47:37 -0500 Subject: [PATCH 33/40] Add lambda lifting tests --- package.yaml | 9 ++ src/IR/Types/Poly.hs | 2 +- test/ir-to-ir/Spec.hs | 1 + test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs | 84 +++++++++++++++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 test/ir-to-ir/Spec.hs create mode 100644 test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs diff --git a/package.yaml b/package.yaml index ed0e5f02..89389f2e 100644 --- a/package.yaml +++ b/package.yaml @@ -90,3 +90,12 @@ tests: - hspec - containers - comonad + ir-to-ir-test: + main: Spec.hs + source-dirs: + - test/ir-to-ir + dependencies: + - sslang + - hspec + - containers + - comonad diff --git a/src/IR/Types/Poly.hs b/src/IR/Types/Poly.hs index d4934b03..871b5e74 100644 --- a/src/IR/Types/Poly.hs +++ b/src/IR/Types/Poly.hs @@ -44,7 +44,7 @@ data Type = TBuiltin (Builtin Type) -- ^ Builtin types | TCon TConId [Type] -- ^ Type constructor, e.g., @Option '0@ | TVar TVarIdx -- ^ Type variables, e.g., @'0@ - deriving Eq + deriving (Eq, Show) instance TypeSystem Type where projectBuiltin = TBuiltin diff --git a/test/ir-to-ir/Spec.hs b/test/ir-to-ir/Spec.hs new file mode 100644 index 00000000..a824f8c3 --- /dev/null +++ b/test/ir-to-ir/Spec.hs @@ -0,0 +1 @@ +{-# OPTIONS_GHC -F -pgmF hspec-discover #-} diff --git a/test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs b/test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs new file mode 100644 index 00000000..3084af4b --- /dev/null +++ b/test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs @@ -0,0 +1,84 @@ +module Tests.LiftProgramLambdasSpec where + +import qualified IR.IR as I +import qualified IR.Types.Classes as C + +import IR.LambdaLift +import Common.Pretty ( spaghetti ) +import Common.Compiler ( runPass ) +import Front.Ast +import Front.Parser ( parseProgram ) +import Front.ParseOperators ( parseOperators ) +import IR.LowerAst ( lowerProgram ) +import IR.ClassInstantiation (instProgram) +import IR.Types.Poly as Poly +import IR.TypeInference ( inferExpr + , inferProgram + ) + +import Control.Comonad ( Comonad(..) ) +import Data.Bifunctor ( Bifunctor(second) ) +import Test.Hspec ( Spec(..) + , it + , pending + , shouldBe + , describe + ) + +lowerAndLift :: Either String Program -> Either String (I.Program Poly.Type) +lowerAndLift (Left e) = Left "Failed to parse program" +lowerAndLift (Right p) = + case runPass $ lowerProgram (parseOperators p) >>= inferProgram >>= instProgram >>= liftProgramLambdas of + Left e' -> Left $ show e' + Right p' -> Right p' + +spec :: Spec +spec = do + + it "lifts a lambda without free variables" $ do + let nonLiftedProgram = parseProgram $ unlines + [ "bar: Int = 5" + , "baz x: Int -> Int =" + , " x + 1" + , "foo y: Int -> Int =" + , " let adder (z: Int) -> Int = z + 1" + , " adder y" + ] + liftedProgram = parseProgram $ unlines + [ "bar: Int = 5" + , "baz x: Int -> Int =" + , " x + 1" + , "anon0 z: Int -> Int =" + , " z + 1" + , "foo y: Int -> Int =" + , " let adder = anon0" + , " adder y" + ] + nestedToLifted = lowerAndLift nonLiftedProgram + liftedToLifted = lowerAndLift liftedProgram + nestedToLifted `shouldBe` liftedToLifted + + it "lifts a lambda with free variables" $ do + let nonLiftedProgram = parseProgram $ unlines + [ "bar: Int = 5" + , "baz x: Int -> Int =" + , " x + 1" + , "foo y: Int -> Int =" + , " let w = 1" + , " adder (z: Int) -> Int = z + bar + w" + , " adder y" + ] + liftedProgram = parseProgram $ unlines + [ "bar: Int = 5" + , "baz x: Int -> Int =" + , " x + 1" + , "anon0 (w: Int) (z: Int) -> Int =" + , " z + bar + w" + , "foo y: Int -> Int =" + , " let w = 1" + , " adder = anon0 w" + , " adder y" + ] + nestedToLifted = lowerAndLift nonLiftedProgram + liftedToLifted = lowerAndLift liftedProgram + nestedToLifted `shouldBe` liftedToLifted From 30e1cdd7ee7e19fe7459490313be4b35a1ea3530 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Sun, 28 Nov 2021 20:02:04 -0500 Subject: [PATCH 34/40] Remove type variables from state monad helpers --- src/IR/LambdaLift.hs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 1eabee04..a1cbc5a1 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -96,7 +96,9 @@ newScope vs = modify $ \st -> st } -- | Context management for lifting new scopes (restore information after lifting the body). -descend :: LiftFn a -> LiftFn (a, M.Map VarId Poly.Type) +descend + :: LiftFn (I.Expr Poly.Type) + -> LiftFn (I.Expr Poly.Type, M.Map VarId Poly.Type) descend body = do savedScope <- gets currentScope savedFreeTypes <- gets freeTypes @@ -106,7 +108,9 @@ descend body = do return (liftedBody, freeTypesBody) -- | Context management for lifting top level lambda definitions. -extractLifted :: LiftFn a -> LiftFn (a, [(I.VarId, I.Expr Poly.Type)]) +extractLifted + :: LiftFn (I.Expr Poly.Type) + -> LiftFn (I.Expr Poly.Type, [(I.VarId, I.Expr Poly.Type)]) extractLifted body = do liftedBody <- body newTopDefs <- gets lifted From 36fc1c132fe02ec89217c3eb5c78cc105418fd4d Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 7 Dec 2021 13:33:48 -0500 Subject: [PATCH 35/40] Tweak comments and lam lift compiler option name --- app/Main.hs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/app/Main.hs b/app/Main.hs index 18b844cd..e1913644 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -28,12 +28,12 @@ import System.IO ( hPrint ) data Mode - = DumpTokens -- ^ Token stream before parsing - | DumpAST -- ^ AST before operator parsing - | DumpASTP -- ^ AST after operators are parsed - | DumpIR -- ^ Intermediate representation - | DumpIRLifted -- ^ Intermediate representation with lifted lambdas - | GenerateC -- ^ Generate C backend + = DumpTokens -- ^ Token stream before parsing + | DumpAST -- ^ AST before operator parsing + | DumpASTP -- ^ AST after operators are parsed + | DumpIR -- ^ Intermediate representation + | DumpIRLambdasLifted -- ^ Intermediate representation with lifted lambdas + | GenerateC -- ^ Generate C backend deriving (Eq, Show) data Options = Options @@ -71,8 +71,8 @@ optionDescriptions = (NoArg (\opt -> return opt { optMode = DumpIR })) "Print the IR" , Option "" - ["dump-lifted-ir"] - (NoArg (\opt -> return opt { optMode = DumpIRLifted })) + ["dump-lambdas-lifted-ir"] + (NoArg (\opt -> return opt { optMode = DumpIRLambdasLifted })) "Print the IR with lifted lambdas" , Option "" ["generate-c"] @@ -134,7 +134,7 @@ main = do irL <- doPass $ IR.lambdaLift irY - when (optMode opts == DumpIRLifted) $ putStrLn (spaghetti irL) >> exitSuccess + when (optMode opts == DumpIRLambdasLifted) $ putStrLn (spaghetti irL) >> exitSuccess irI <- doPass $ IR.defunctionalize irL From d58235c456bd25abecec8d1306e52bb304dde56b Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 7 Dec 2021 13:41:24 -0500 Subject: [PATCH 36/40] Use epic foldr in makeLambdaChain --- src/IR/IR.hs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/IR/IR.hs b/src/IR/IR.hs index 6ae7455e..2fd56d63 100644 --- a/src/IR/IR.hs +++ b/src/IR/IR.hs @@ -192,10 +192,9 @@ collectLambda e = ([], e) -- | Create a lambda chain given a list of argument-type pairs and a body. makeLambdaChain :: TypeSystem t => [(Binder, t)] -> Expr t -> Expr t -makeLambdaChain [] body = body -makeLambdaChain ((v, ty) : vs) body = - let liftedBody = makeLambdaChain vs body - in Lambda v liftedBody $ arrow ty (extract liftedBody) +makeLambdaChain [] body = body +makeLambdaChain args body = foldr chain body args + where chain (v, t) b = Lambda v b $ t `arrow` extract b instance Functor Program where fmap f Program { programEntry = e, programDefs = defs, typeDefs = tds } = From 8003c49facba0616badf34f954e35b1df06fc3eb Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 7 Dec 2021 14:20:01 -0500 Subject: [PATCH 37/40] Add docs for LiftCtx fields --- src/IR/LambdaLift.hs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index a1cbc5a1..44d4961d 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -31,10 +31,23 @@ import qualified Data.Set as S -- | Lifting Environment data LiftCtx = LiftCtx { globalScope :: S.Set I.VarId + {- ^ `globalScope` is a set containing top-level identifiers. All scopes, + regardless of depth, have access to these identifiers. + -} , currentScope :: S.Set I.VarId + {- ^ 'currentScope` is a set containing the identifiers available in the + current scope. + -} , freeTypes :: M.Map I.VarId Poly.Type + {- ^ `freeTypes` maps an identifier for a free variable to its type. -} , lifted :: [(I.VarId, I.Expr Poly.Type)] + {- ^ `lifted` is a list of lifted lambdas created while descending into a + top-level definition. + -} , anonCount :: Int + {- ^ `anonCount` is a monotonically increasing counter used for creating + unique identifiers for lifted lambdas. + -} } -- Lift Monad From b1098d6d1ea6cf46a6df7f768ccf35b87aaa30c2 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 7 Dec 2021 14:33:05 -0500 Subject: [PATCH 38/40] Remove redundant base case from makeLambdaChain --- src/IR/IR.hs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/IR/IR.hs b/src/IR/IR.hs index 2fd56d63..1dff0878 100644 --- a/src/IR/IR.hs +++ b/src/IR/IR.hs @@ -192,7 +192,6 @@ collectLambda e = ([], e) -- | Create a lambda chain given a list of argument-type pairs and a body. makeLambdaChain :: TypeSystem t => [(Binder, t)] -> Expr t -> Expr t -makeLambdaChain [] body = body makeLambdaChain args body = foldr chain body args where chain (v, t) b = Lambda v b $ t `arrow` extract b From bd51557881ec19c92262ba16c015d156e5f65b16 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 7 Dec 2021 17:02:59 -0500 Subject: [PATCH 39/40] Take into account that after descending, we may inherit more free variables --- src/IR/LambdaLift.hs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/IR/LambdaLift.hs b/src/IR/LambdaLift.hs index 44d4961d..191d37ea 100644 --- a/src/IR/LambdaLift.hs +++ b/src/IR/LambdaLift.hs @@ -117,7 +117,11 @@ descend body = do savedFreeTypes <- gets freeTypes liftedBody <- body freeTypesBody <- gets freeTypes - modify $ \st -> st { currentScope = savedScope, freeTypes = savedFreeTypes } + modify $ \st -> st + { currentScope = savedScope + , freeTypes = M.union (S.foldl (flip M.delete) freeTypesBody savedScope) + savedFreeTypes + } return (liftedBody, freeTypesBody) -- | Context management for lifting top level lambda definitions. @@ -128,7 +132,7 @@ extractLifted body = do liftedBody <- body newTopDefs <- gets lifted modify $ \st -> st { lifted = [] } - return (liftedBody, newTopDefs) + return (liftedBody, reverse newTopDefs) {- | Entry-point to lambda lifting. From e609da79521e42fec8ed74d842feb196c9401767 Mon Sep 17 00:00:00 2001 From: Hans Montero Date: Tue, 7 Dec 2021 17:03:13 -0500 Subject: [PATCH 40/40] Add extra nested lambda test --- test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs b/test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs index 3084af4b..1542c536 100644 --- a/test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs +++ b/test/ir-to-ir/Tests/LiftProgramLambdasSpec.hs @@ -82,3 +82,27 @@ spec = do nestedToLifted = lowerAndLift nonLiftedProgram liftedToLifted = lowerAndLift liftedProgram nestedToLifted `shouldBe` liftedToLifted + + it "lifts nested lambdas with free variables" $ do + let nonLiftedProgram = parseProgram $ unlines + [ "foo (x: Int) (y: Int) -> Int =" + , " let z = 5" + , " g (a: Int) -> Int =" + , " let h (b: Int) -> Int = a + b + x" + , " h z" + , " g y" + ] + liftedProgram = parseProgram $ unlines + [ "anon0 (a: Int) (x: Int) (b: Int) -> Int =" + , " a + b + x" + , "anon1 (x: Int) (z: Int) (a: Int) -> Int =" + , " let h = anon0 a x" + , " h z" + , "foo (x: Int) (y: Int) -> Int =" + , " let z = 5" + , " g = anon1 x z" + , " g y" + ] + nestedToLifted = lowerAndLift nonLiftedProgram + liftedToLifted = lowerAndLift liftedProgram + nestedToLifted `shouldBe` liftedToLifted