Skip to content

Commit

Permalink
Lift lambdas (#44)
Browse files Browse the repository at this point in the history
* WIP

* mvp: detect free vars

* Improve free variable collection

* wip: Lift lambdas using state monad

* wip: Lift let expressions, start testing, but still no callsite adjustments for named lambdas

* Use a jumping off function for lambda lifting

* Don't treat 'named' lambdas specially

* Format code

* Lift prims

* Remove unneeded code

* Tweak coding style

* Fix free var detection bug

* Exploit very epic laziness to descend into lambda bodies

* Add compiler option for dumping lifted IR

* Reconstruct lambdas with correct types

* Fix lambda free-var application types

* Add lifting for match expressions

* Rebase on main and update match code

* Use makeLiftedLambda and insert some trace

* Use function type signature for top level lambda collection

* Use makeLiftedLambda in liftLambdas' after regression test fix in main

* Use extract

* Rewrite applyFreesToLambda as an epic fold

* Don't use catch-all in liftLambdas

* Rename liftLambdas' to liftLambdasTop

* Use 'arrow' instead of redefining it

* Use collectArrow

* Reorder definitions

* Use unzip instead of map fst

* Use IR makeChainedLambda export

* Preserve relative top def order after lifting

* Write some very epic haddock comments for LambdaLift module

* Add lambda lifting tests

* Remove type variables from state monad helpers

* Tweak comments and lam lift compiler option name

* Use epic foldr in makeLambdaChain

* Add docs for LiftCtx fields

* Remove redundant base case from makeLambdaChain

* Take into account that after descending, we may inherit more free variables

* Add extra nested lambda test
  • Loading branch information
hmontero1205 authored Dec 8, 2021
1 parent 66e86e9 commit e88ed8b
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 13 deletions.
27 changes: 17 additions & 10 deletions app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ import System.IO ( hPrint
)

data Mode
= DumpTokens -- ^ Token stream before parsing
| DumpAST -- ^ AST before operator parsing
| DumpASTP -- ^ AST after operators are parsed
| DumpIR -- ^ Intermediate representation
| GenerateC -- ^ Generate C backend
= DumpTokens -- ^ Token stream before parsing
| DumpAST -- ^ AST before operator parsing
| DumpASTP -- ^ AST after operators are parsed
| DumpIR -- ^ Intermediate representation
| DumpIRLambdasLifted -- ^ Intermediate representation with lifted lambdas
| GenerateC -- ^ Generate C backend
deriving (Eq, Show)

data Options = Options
Expand Down Expand Up @@ -69,6 +70,10 @@ optionDescriptions =
["dump-ir"]
(NoArg (\opt -> return opt { optMode = DumpIR }))
"Print the IR"
, Option ""
["dump-lambdas-lifted-ir"]
(NoArg (\opt -> return opt { optMode = DumpIRLambdasLifted }))
"Print the IR with lifted lambdas"
, Option ""
["generate-c"]
(NoArg (\opt -> return opt { optMode = GenerateC }))
Expand Down Expand Up @@ -113,21 +118,23 @@ main = do
when (optMode opts == DumpAST) $ putStrLn (spaghetti ast) >> exitSuccess

ast' <- doPass $ Front.desugarAst ast
when (optMode opts == DumpASTP) $ putStrLn (spaghetti ast') >> exitSuccess
when (optMode opts == DumpASTP) $ putStrLn (spaghetti ast') >> exitSuccess

() <- doPass $ Front.checkAst ast'

irA <- doPass $ IR.lowerAst ast'

when (optMode opts == DumpIR) $ putStrLn (spaghetti irA) >> exitSuccess

irC <- doPass $ IR.inferTypes irA
irC <- doPass $ IR.inferTypes irA

irP <- doPass $ IR.instantiateClasses irC

irP <- doPass $ IR.instantiateClasses irC
irY <- doPass $ IR.yieldAbstraction irP

irY <- doPass $ IR.yieldAbstraction irP
irL <- doPass $ IR.lambdaLift irY

irL <- doPass $ IR.lambdaLift irY
when (optMode opts == DumpIRLambdasLifted) $ putStrLn (spaghetti irL) >> exitSuccess

irI <- doPass $ IR.defunctionalize irL

Expand Down
9 changes: 9 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,12 @@ tests:
- hspec
- containers
- comonad
ir-to-ir-test:
main: Spec.hs
source-dirs:
- test/ir-to-ir
dependencies:
- sslang
- hspec
- containers
- comonad
3 changes: 2 additions & 1 deletion src/IR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import IR.ClassInstantiation ( instProgram )
import IR.LowerAst ( lowerProgram )
import IR.Monomorphize ( monoProgram )
import IR.TypeInference ( inferProgram )
import IR.LambdaLift ( liftProgramLambdas )

lowerAst :: A.Program -> Compiler.Pass (I.Program Ann.Type)
lowerAst = lowerProgram
Expand All @@ -30,7 +31,7 @@ yieldAbstraction :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type)
yieldAbstraction = return

lambdaLift :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type)
lambdaLift = return
lambdaLift = liftProgramLambdas

defunctionalize :: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type)
defunctionalize = return
Expand Down
11 changes: 10 additions & 1 deletion src/IR/IR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module IR.IR
, DConId(..)
, wellFormed
, collectLambda
, makeLambdaChain
) where
import Common.Identifiers ( Binder
, DConId(..)
Expand All @@ -20,7 +21,10 @@ import Common.Identifiers ( Binder

import Control.Comonad ( Comonad(..) )
import Data.Bifunctor ( Bifunctor(..) )
import IR.Types.TypeSystem ( TypeDef(..) )
import IR.Types.TypeSystem ( TypeDef(..)
, TypeSystem
, arrow
)

import Common.Pretty

Expand Down Expand Up @@ -186,6 +190,11 @@ collectLambda (Lambda a b _) =
let (as, body) = collectLambda b in (a : as, body)
collectLambda e = ([], e)

-- | Create a lambda chain given a list of argument-type pairs and a body.
makeLambdaChain :: TypeSystem t => [(Binder, t)] -> Expr t -> Expr t
makeLambdaChain args body = foldr chain body args
where chain (v, t) b = Lambda v b $ t `arrow` extract b

instance Functor Program where
fmap f Program { programEntry = e, programDefs = defs, typeDefs = tds } =
Program { programEntry = e
Expand Down
228 changes: 228 additions & 0 deletions src/IR/LambdaLift.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
{-# LANGUAGE DerivingVia #-}
module IR.LambdaLift
( liftProgramLambdas
) where

import qualified Common.Compiler as Compiler
import Common.Identifiers
import qualified IR.IR as I

import qualified IR.Types.Poly as Poly
import IR.Types.TypeSystem ( collectArrow
, dearrow
)

import Control.Comonad ( Comonad(..) )
import Control.Monad.Except ( MonadError(..) )
import Control.Monad.State.Lazy ( MonadState
, StateT(..)
, evalStateT
, gets
, modify
, unless
)

import qualified Data.Bifunctor as B
import qualified Data.Foldable as F
import qualified Data.Map as M
import Data.Maybe ( catMaybes )
import qualified Data.Set as S

-- | Lifting Environment
data LiftCtx = LiftCtx
{ globalScope :: S.Set I.VarId
{- ^ `globalScope` is a set containing top-level identifiers. All scopes,
regardless of depth, have access to these identifiers.
-}
, currentScope :: S.Set I.VarId
{- ^ 'currentScope` is a set containing the identifiers available in the
current scope.
-}
, freeTypes :: M.Map I.VarId Poly.Type
{- ^ `freeTypes` maps an identifier for a free variable to its type. -}
, lifted :: [(I.VarId, I.Expr Poly.Type)]
{- ^ `lifted` is a list of lifted lambdas created while descending into a
top-level definition.
-}
, anonCount :: Int
{- ^ `anonCount` is a monotonically increasing counter used for creating
unique identifiers for lifted lambdas.
-}
}

-- Lift Monad
newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a)
deriving Functor via (StateT LiftCtx Compiler.Pass)
deriving Applicative via (StateT LiftCtx Compiler.Pass)
deriving Monad via (StateT LiftCtx Compiler.Pass)
deriving MonadFail via (StateT LiftCtx Compiler.Pass)
deriving (MonadError Compiler.Error) via (StateT LiftCtx Compiler.Pass)
deriving (MonadState LiftCtx) via (StateT LiftCtx Compiler.Pass)

-- | Run a LiftFn computation.
runLiftFn :: LiftFn a -> Compiler.Pass a
runLiftFn (LiftFn m) = evalStateT
m
LiftCtx { globalScope = S.empty
, currentScope = S.empty
, freeTypes = M.empty
, lifted = []
, anonCount = 0
}

-- | Extract top level definition names that compose program's global scope.
populateGlobalScope :: [(I.VarId, I.Expr Poly.Type)] -> LiftFn ()
populateGlobalScope defs = do
let (globalNames, _) = unzip defs
modify $ \st -> st { globalScope = S.fromList globalNames }

-- | Check if an identifier is in the current program scope.
inCurrentScope :: I.VarId -> LiftFn Bool
inCurrentScope v = S.member v <$> gets currentScope

-- | Add an identifier to the current program scope.
addCurrentScope :: I.VarId -> LiftFn ()
addCurrentScope v =
modify $ \st -> st { currentScope = S.insert v $ currentScope st }

-- | Get a fresh variable name
getFresh :: LiftFn String
getFresh = do
curCount <- gets anonCount
modify $ \st -> st { anonCount = anonCount st + 1 }
return $ "anon" ++ show curCount

-- | Store a new lifted lambda to later add to the program's top level definitions.
addLifted :: String -> I.Expr Poly.Type -> LiftFn ()
addLifted name lam =
modify $ \st -> st { lifted = (I.VarId (Identifier name), lam) : lifted st }

-- | Register a (free variable, type) mapping for the current program scope.
addFreeVar :: I.VarId -> Poly.Type -> LiftFn ()
addFreeVar v t = modify $ \st -> st { freeTypes = M.insert v t $ freeTypes st }

-- | Update lift environment before entering a new scope (e.g. lambda body, match arm).
newScope :: [I.VarId] -> LiftFn ()
newScope vs = modify $ \st -> st
{ currentScope = S.union (globalScope st) (S.fromList vs)
, freeTypes = M.empty
}

-- | Context management for lifting new scopes (restore information after lifting the body).
descend
:: LiftFn (I.Expr Poly.Type)
-> LiftFn (I.Expr Poly.Type, M.Map VarId Poly.Type)
descend body = do
savedScope <- gets currentScope
savedFreeTypes <- gets freeTypes
liftedBody <- body
freeTypesBody <- gets freeTypes
modify $ \st -> st
{ currentScope = savedScope
, freeTypes = M.union (S.foldl (flip M.delete) freeTypesBody savedScope)
savedFreeTypes
}
return (liftedBody, freeTypesBody)

-- | Context management for lifting top level lambda definitions.
extractLifted
:: LiftFn (I.Expr Poly.Type)
-> LiftFn (I.Expr Poly.Type, [(I.VarId, I.Expr Poly.Type)])
extractLifted body = do
liftedBody <- body
newTopDefs <- gets lifted
modify $ \st -> st { lifted = [] }
return (liftedBody, reverse newTopDefs)

{- | Entry-point to lambda lifting.
Maps over top level definitions and lifts out lambda definitions to create a new
Program with the relative order of user definitions preserved.
-}
liftProgramLambdas
:: I.Program Poly.Type -> Compiler.Pass (I.Program Poly.Type)
liftProgramLambdas p = runLiftFn $ do
let defs = I.programDefs p
populateGlobalScope defs
liftedProgramDefs <- mapM liftLambdasTop defs
return $ p { I.programDefs = concat liftedProgramDefs }

-- | Given a top-level definition, lift out any lambda definitions.
liftLambdasTop
:: (I.VarId, I.Expr Poly.Type) -> LiftFn [(I.VarId, I.Expr Poly.Type)]
liftLambdasTop (v, lam@(I.Lambda _ _ t)) = do
let (vs, body) = I.collectLambda lam
vs' = zip vs $ fst (collectArrow t)
newScope $ catMaybes vs
(liftedBody, newTopDefs) <- extractLifted $ liftLambdas body
let liftedLambda = I.makeLambdaChain vs' liftedBody
return $ newTopDefs ++ [(v, liftedLambda)]
liftLambdasTop topDef = return [topDef]

{- | Lifting logic for IR expressions.
As we traverse over IR expressions, we note down any bindings we encounter so
that we can detect free variables. For lambda definitions, we use free
variables to create a new top-level lifted equivalent and then adjust the
callsite by partially-applying the new lifted lambda with those free variables
from the surrounding the scope.
-}
liftLambdas :: I.Expr Poly.Type -> LiftFn (I.Expr Poly.Type)
liftLambdas n@(I.Var v t) = do
inScope <- inCurrentScope v
unless inScope $ addFreeVar v t
return n
liftLambdas (I.App e1 e2 t) = do
liftedE1 <- liftLambdas e1
liftedE2 <- liftLambdas e2
return $ I.App liftedE1 liftedE2 t
liftLambdas (I.Prim p exprs t) = do
liftedExprs <- mapM liftLambdas exprs
return $ I.Prim p liftedExprs t
liftLambdas lam@(I.Lambda _ _ t) = do
let (vs, body) = I.collectLambda lam
vs' = zip vs $ fst (collectArrow t)
(liftedLamBody, lamFreeTypes) <-
descend $ newScope (catMaybes vs) >> liftLambdas body
let liftedLam = I.makeLambdaChain
(map (B.first Just) (M.toList lamFreeTypes) ++ vs')
liftedLamBody
freshName <- getFresh
addLifted freshName liftedLam
return $ foldl applyFree
(I.Var (I.VarId (Identifier freshName)) (extract liftedLam))
(M.toList lamFreeTypes)
where
applyFree app (v', t') = case dearrow $ extract app of
Just (_, tr) -> I.App app (I.Var v' t') tr
Nothing -> app
liftLambdas (I.Let bs e t) = do
let (vs, exprs) = unzip bs
mapM_ addCurrentScope (catMaybes vs)
liftedLetBodies <- mapM liftLambdas exprs
liftedExpr <- liftLambdas e
return $ I.Let (zip vs liftedLetBodies) liftedExpr t
liftLambdas (I.Match s arms t) = do
liftedMatch <- liftLambdas s
liftedArms <- mapM liftLambdasInArm arms
return $ I.Match liftedMatch liftedArms t
liftLambdas lit@I.Lit{} = return lit
liftLambdas dat@I.Data{} = return dat

-- | Entry point for traversing the arms of match expressions.
liftLambdasInArm
:: (I.Alt, I.Expr Poly.Type) -> LiftFn (I.Alt, I.Expr Poly.Type)
liftLambdasInArm (I.AltLit l, arm) = do
(liftedArm, armFrees) <- descend $ liftLambdas arm
modify $ \st -> st { freeTypes = armFrees }
return (I.AltLit l, liftedArm)
liftLambdasInArm (I.AltDefault b, arm) = do
(liftedArm, armFrees) <-
descend $ F.forM_ b addCurrentScope >> liftLambdas arm
modify $ \st -> st { freeTypes = armFrees }
return (I.AltDefault b, liftedArm)
liftLambdasInArm (I.AltData d bs, arm) = do
(liftedArm, armFrees) <-
descend $ mapM_ addCurrentScope (catMaybes bs) >> liftLambdas arm
modify $ \st -> st { freeTypes = armFrees }
return (I.AltData d bs, liftedArm)
2 changes: 1 addition & 1 deletion src/IR/Types/Poly.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ data Type
= TBuiltin (Builtin Type) -- ^ Builtin types
| TCon TConId [Type] -- ^ Type constructor, e.g., @Option '0@
| TVar TVarIdx -- ^ Type variables, e.g., @'0@
deriving Eq
deriving (Eq, Show)

instance TypeSystem Type where
projectBuiltin = TBuiltin
Expand Down
1 change: 1 addition & 0 deletions test/ir-to-ir/Spec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{-# OPTIONS_GHC -F -pgmF hspec-discover #-}
Loading

0 comments on commit e88ed8b

Please sign in to comment.