From 5e8188e2d6c9e0501cba9dbcf5b74eb5e5797dfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oskar=20Wickstr=C3=B6m?= Date: Sun, 24 Jan 2016 17:00:41 +0100 Subject: [PATCH] Subsume type signatures with expression types --- src/oden/Oden/Backend/Go.hs | 29 +++++++----- src/oden/Oden/Compiler.hs | 12 +++-- src/oden/Oden/Infer.hs | 73 +++++------------------------ src/oden/Oden/Infer/Substitution.hs | 63 +++++++++++++++++++++++++ src/oden/Oden/Infer/Subsumption.hs | 69 +++++++++++++++++++++++++++ src/oden/Oden/Output/Infer.hs | 20 +++++--- test/Oden/CompilerSpec.hs | 5 ++ test/Oden/Infer/SubsumptionSpec.hs | 53 +++++++++++++++++++++ test/Oden/InferSpec.hs | 66 ++++++++++++++++++++++++++ 9 files changed, 306 insertions(+), 84 deletions(-) create mode 100644 src/oden/Oden/Infer/Substitution.hs create mode 100644 src/oden/Oden/Infer/Subsumption.hs create mode 100644 test/Oden/Infer/SubsumptionSpec.hs diff --git a/src/oden/Oden/Backend/Go.hs b/src/oden/Oden/Backend/Go.hs index e59bbe8..b352eb3 100644 --- a/src/oden/Oden/Backend/Go.hs +++ b/src/oden/Oden/Backend/Go.hs @@ -26,14 +26,17 @@ func name arg returnType body = <+> returnType <+> if isEmpty body then empty else block body -var :: Name -> Expr Mono.Type -> Doc -var name expr = +varWithType :: Name -> Mono.Type -> Expr Mono.Type -> Doc +varWithType name mt expr = text "var" <+> safeName name - <+> codegenType (typeOf expr) + <+> codegenType mt <+> equals <+> codegenExpr expr +var :: Name -> Expr Mono.Type -> Doc +var name expr = varWithType name (typeOf expr) expr + return' :: Expr Mono.Type -> Doc return' e@(Application _ _ t) | t == Mono.typeUnit = codegenExpr e $+$ text "return" @@ -140,21 +143,25 @@ codegenExpr (If condExpr thenExpr elseExpr t) = codegenExpr (Slice exprs t) = codegenType t <> braces (hcat (punctuate (comma <+> space) (map codegenExpr exprs))) -codegenTopLevel :: Name -> Expr Mono.Type -> Doc -codegenTopLevel name (NoArgFn body (Mono.TNoArgFn r)) = +codegenTopLevel :: Name -> Mono.Type -> Expr Mono.Type -> Doc +codegenTopLevel name (Mono.TNoArgFn r) (NoArgFn body _) = func (safeName name) empty (codegenType r) (return' body) -codegenTopLevel name (Fn a body (Mono.TFn d r)) = +codegenTopLevel name _ (NoArgFn body (Mono.TNoArgFn r)) = + func (safeName name) empty (codegenType r) (return' body) +codegenTopLevel name (Mono.TFn d r) (Fn a body _) = + func (safeName name) (funcArg a d) (codegenType r) (return' body) +codegenTopLevel name _ (Fn a body (Mono.TFn d r)) = func (safeName name) (funcArg a d) (codegenType r) (return' body) -codegenTopLevel name expr = - var name expr +codegenTopLevel name t expr = + varWithType name t expr codegenInstance :: InstantiatedDefinition -> Doc codegenInstance (InstantiatedDefinition name expr) = - codegenTopLevel name expr + codegenTopLevel name (typeOf expr) expr codegenMonomorphed :: MonomorphedDefinition -> Doc -codegenMonomorphed (MonomorphedDefinition name expr) = - codegenTopLevel name expr +codegenMonomorphed (MonomorphedDefinition name mt expr) = + codegenTopLevel name mt expr codegenImport :: Import -> Doc codegenImport (Import name) = diff --git a/src/oden/Oden/Compiler.hs b/src/oden/Oden/Compiler.hs index d0ef1a8..b0c842c 100644 --- a/src/oden/Oden/Compiler.hs +++ b/src/oden/Oden/Compiler.hs @@ -13,7 +13,7 @@ import Oden.Scope as Scope import qualified Oden.Type.Monomorphic as Mono import qualified Oden.Type.Polymorphic as Poly -data MonomorphedDefinition = MonomorphedDefinition Name (Core.Expr Mono.Type) +data MonomorphedDefinition = MonomorphedDefinition Name Mono.Type (Core.Expr Mono.Type) deriving (Show, Eq, Ord) data InstantiatedDefinition = @@ -198,11 +198,13 @@ unwrapLetInstances [] body = body unwrapLetInstances (LetInstance mn me:is) body = Core.Let mn me (unwrapLetInstances is body) (Core.typeOf body) monomorphDefinition :: Core.Definition -> Monomorph () -monomorphDefinition d@(Core.Definition name (s, expr)) = do +monomorphDefinition d@(Core.Definition name (Poly.Forall _ st, expr)) = do addToScope d - unless (Poly.isPolymorphic s) $ do - mExpr <- monomorph expr - addMonomorphed name (MonomorphedDefinition name mExpr) + case Poly.toMonomorphic st of + Left _ -> return () + Right mt -> do + mExpr <- monomorph expr + addMonomorphed name (MonomorphedDefinition name mt mExpr) monomorphPackage :: Scope -> Core.Package -> Either CompilationError CompiledPackage monomorphPackage scope' (Core.Package name imports definitions) = do diff --git a/src/oden/Oden/Infer.hs b/src/oden/Oden/Infer.hs index e9e2ba3..77801ba 100644 --- a/src/oden/Oden/Infer.hs +++ b/src/oden/Oden/Infer.hs @@ -1,5 +1,4 @@ {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE TypeSynonymInstances #-} module Oden.Infer ( @@ -12,18 +11,22 @@ module Oden.Infer ( constraintsExpr ) where +import Control.Arrow (left) import Control.Monad.Except import Control.Monad.Identity import Control.Monad.RWS hiding ((<>)) import Data.List (nub) import qualified Data.Map as Map +import Data.Maybe import qualified Data.Set as Set import qualified Oden.Core as Core import qualified Oden.Core.Untyped as Untyped import Oden.Env as Env import Oden.Identifier +import Oden.Infer.Substitution +import Oden.Infer.Subsumption import Oden.Type.Polymorphic ------------------------------------------------------------------------------- @@ -53,32 +56,6 @@ type Unifier = (Subst, [Constraint]) -- | Constraint solver monad type Solve a = ExceptT TypeError Identity a -newtype Subst = Subst (Map.Map TVar Type) - deriving (Eq, Ord, Show, Monoid) - -class FTV a => Substitutable a where - apply :: Subst -> a -> a - -instance Substitutable Type where - apply _ TAny = TAny - apply _ (TCon a) = TCon a - apply (Subst s) t@(TVar a) = Map.findWithDefault t a s - apply s (TNoArgFn t) = TNoArgFn (apply s t) - apply s (t1 `TFn` t2) = apply s t1 `TFn` apply s t2 - apply s (TUncurriedFn as r) = TUncurriedFn (map (apply s) as) (apply s r) - apply s (TVariadicFn as v r) = TVariadicFn (map (apply s) as) (apply s v) (apply s r) - apply s (TSlice t) = TSlice (apply s t) - - -instance Substitutable Scheme where - apply (Subst s) (Forall as t) = Forall as $ apply s' t - where s' = Subst $ foldr Map.delete s as - -instance FTV Core.CanonicalExpr where - ftv (sc, expr) = ftv sc `Set.union` ftv expr - -instance Substitutable Core.CanonicalExpr where - apply s (sc, expr) = (apply s sc, apply s expr) instance FTV Constraint where ftv (t1, t2) = ftv t1 `Set.union` ftv t2 @@ -86,30 +63,12 @@ instance FTV Constraint where instance Substitutable Constraint where apply s (t1, t2) = (apply s t1, apply s t2) -instance Substitutable a => Substitutable [a] where - apply = map . apply - instance FTV Env where ftv (TypeEnv env) = ftv $ Map.elems env instance Substitutable Env where apply s (TypeEnv env) = TypeEnv $ Map.map (apply s) env -instance FTV (Core.Expr Type) where - ftv = ftv . Core.typeOf - -instance Substitutable (Core.Expr Type) where - apply s (Core.Symbol x t) = Core.Symbol x (apply s t) - apply s (Core.Application f p t) = Core.Application (apply s f) (apply s p) (apply s t) - apply s (Core.NoArgApplication f t) = Core.NoArgApplication (apply s f) (apply s t) - apply s (Core.UncurriedFnApplication f p t) = Core.UncurriedFnApplication (apply s f) (apply s p) (apply s t) - apply s (Core.Fn x b t) = Core.Fn x (apply s b) (apply s t) - apply s (Core.NoArgFn b t) = Core.NoArgFn (apply s b) (apply s t) - apply s (Core.Let x e b t) = Core.Let x (apply s e) (apply s b) (apply s t) - apply s (Core.Literal l t) = Core.Literal l (apply s t) - apply s (Core.If c tb fb t) = Core.If (apply s c) (apply s tb) (apply s fb) (apply s t) - apply s (Core.Slice es t) = Core.Slice (apply s es) (apply s t) - data TypeError = UnificationFail Type Type | InfiniteType TVar Type @@ -117,6 +76,7 @@ data TypeError | Ambigious [Constraint] | UnificationMismatch [Type] [Type] | ArgumentCountMismatch [Type] [Type] + | TypeSignatureSubsumptionError Name SubsumptionError deriving (Show, Eq) ------------------------------------------------------------------------------- @@ -268,13 +228,11 @@ infer expr = case expr of return (Core.Slice tes (TSlice tv)) inferDef :: Untyped.Definition -> Infer Core.Definition -inferDef (Untyped.Definition name (Just sc) expr) = do - te <- inEnv (Unqualified name, sc) (infer expr) - return (Core.Definition name (sc, te)) -inferDef (Untyped.Definition name Nothing expr) = do +inferDef (Untyped.Definition name s expr) = do tv <- fresh env <- ask - te <- inEnv (Unqualified name, Forall [] tv) (infer expr) + let recType = fromMaybe (Forall [] tv) s + te <- inEnv (Unqualified name, recType) (infer expr) return (Core.Definition name (generalize env te)) inferDefinition :: Env -> Untyped.Definition -> Either TypeError Core.Definition @@ -282,10 +240,12 @@ inferDefinition env def@(Untyped.Definition _ Nothing _) = do (Core.Definition name (_, te), cs) <- runInfer env (inferDef def) subst <- runSolve cs return $ Core.Definition name (closeOver (apply subst te)) -inferDefinition env def = do +inferDefinition env def@(Untyped.Definition _ (Just st) _) = do (Core.Definition name ce, cs) <- runInfer env (inferDef def) subst <- runSolve cs - return $ Core.Definition name (apply subst ce) + let (Forall _ _, substExpr) = apply subst ce + ce' <- left (TypeSignatureSubsumptionError name) $ subsumeTypeSignature st substExpr + return $ Core.Definition name ce' inferPackage :: Env -> Untyped.Package -> Either TypeError (Core.Package, Env) inferPackage env (Untyped.Package name imports defs) = do @@ -327,14 +287,6 @@ normalize (Forall _ body, te) = (Forall (map snd ord) (normtype body), te) -- Constraint Solver ------------------------------------------------------------------------------- --- | The empty substitution -emptySubst :: Subst -emptySubst = mempty - --- | Compose substitutions -compose :: Subst -> Subst -> Subst -(Subst s1) `compose` (Subst s2) = Subst $ Map.map (apply (Subst s1)) s2 `Map.union` s1 - -- | Run the constraint solver runSolve :: [Constraint] -> Either TypeError Subst runSolve cs = runIdentity $ runExceptT $ solver st @@ -353,7 +305,6 @@ unifies t1 t2 | t1 == t2 = return emptySubst unifies TAny (TVar v) = v `bind` TAny unifies (TVar v) TAny = v `bind` TAny unifies TAny _ = return emptySubst -unifies _ TAny = return emptySubst unifies (TVar v) t = v `bind` t unifies t (TVar v) = v `bind` t unifies (TFn t1 t2) (TFn t3 t4) = unifyMany [t1, t2] [t3, t4] diff --git a/src/oden/Oden/Infer/Substitution.hs b/src/oden/Oden/Infer/Substitution.hs new file mode 100644 index 0000000..1ae163a --- /dev/null +++ b/src/oden/Oden/Infer/Substitution.hs @@ -0,0 +1,63 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE TypeSynonymInstances #-} +module Oden.Infer.Substitution where + +import Oden.Type.Polymorphic +import Oden.Core as Core + +import qualified Data.Set as Set +import qualified Data.Map as Map + +newtype Subst = Subst (Map.Map TVar Type) + deriving (Eq, Ord, Show, Monoid) + +-- | The empty substitution +emptySubst :: Subst +emptySubst = mempty + +-- | Compose substitutions +compose :: Subst -> Subst -> Subst +(Subst s1) `compose` (Subst s2) = Subst $ Map.map (apply (Subst s1)) s2 `Map.union` s1 + +class FTV a => Substitutable a where + apply :: Subst -> a -> a + +instance Substitutable Type where + apply _ TAny = TAny + apply _ (TCon a) = TCon a + apply (Subst s) t@(TVar a) = Map.findWithDefault t a s + apply s (TNoArgFn t) = TNoArgFn (apply s t) + apply s (t1 `TFn` t2) = apply s t1 `TFn` apply s t2 + apply s (TUncurriedFn as r) = TUncurriedFn (map (apply s) as) (apply s r) + apply s (TVariadicFn as v r) = TVariadicFn (map (apply s) as) (apply s v) (apply s r) + apply s (TSlice t) = TSlice (apply s t) + + +instance Substitutable Scheme where + apply (Subst s) (Forall as t) = Forall as $ apply s' t + where s' = Subst $ foldr Map.delete s as + +instance FTV Core.CanonicalExpr where + ftv (sc, expr) = ftv sc `Set.union` ftv expr + +instance Substitutable Core.CanonicalExpr where + apply s (sc, expr) = (apply s sc, apply s expr) + +instance FTV (Core.Expr Type) where + ftv = ftv . Core.typeOf + +instance Substitutable (Core.Expr Type) where + apply s (Core.Symbol x t) = Core.Symbol x (apply s t) + apply s (Core.Application f p t) = Core.Application (apply s f) (apply s p) (apply s t) + apply s (Core.NoArgApplication f t) = Core.NoArgApplication (apply s f) (apply s t) + apply s (Core.UncurriedFnApplication f p t) = Core.UncurriedFnApplication (apply s f) (apply s p) (apply s t) + apply s (Core.Fn x b t) = Core.Fn x (apply s b) (apply s t) + apply s (Core.NoArgFn b t) = Core.NoArgFn (apply s b) (apply s t) + apply s (Core.Let x e b t) = Core.Let x (apply s e) (apply s b) (apply s t) + apply s (Core.Literal l t) = Core.Literal l (apply s t) + apply s (Core.If c tb fb t) = Core.If (apply s c) (apply s tb) (apply s fb) (apply s t) + apply s (Core.Slice es t) = Core.Slice (apply s es) (apply s t) + +instance Substitutable a => Substitutable [a] where + apply = map . apply diff --git a/src/oden/Oden/Infer/Subsumption.hs b/src/oden/Oden/Infer/Subsumption.hs new file mode 100644 index 0000000..dda7512 --- /dev/null +++ b/src/oden/Oden/Infer/Subsumption.hs @@ -0,0 +1,69 @@ +module Oden.Infer.Subsumption ( + Subsuming, + SubsumptionError(..), + subsume, + subsumeTypeSignature +) where + +import Oden.Type.Polymorphic +import Oden.Core as Core +import Oden.Infer.Substitution + +import qualified Data.Map as Map + +data SubsumptionError = SubsumptionError Type Type + deriving (Show, Eq) + +class Subsuming s where + subsume :: s -> s -> Either SubsumptionError s + +subsumeTypeSignature :: Scheme -> Core.Expr Type -> Either SubsumptionError Core.CanonicalExpr +subsumeTypeSignature s@(Forall _ st) expr = do + subst <- getSubst st (Core.typeOf expr) + return (s, apply subst expr) + where + getSubst :: Type -> Type -> Either SubsumptionError Subst + getSubst t (TVar tv) = return (Subst (Map.singleton tv t)) + getSubst (TFn a1 r1) (TFn a2 r2) = do + a <- getSubst a1 a2 + r <- getSubst r1 r2 + return (a `compose` r) + getSubst (TNoArgFn r1) (TNoArgFn r2) = getSubst r1 r2 + getSubst (TUncurriedFn a1 r1) (TUncurriedFn a2 r2) = do + as <- mapM (uncurry getSubst) ((r1, r2) : zip a1 a2) + return (foldl compose emptySubst as) + getSubst (TVariadicFn a1 v1 r1) (TVariadicFn a2 v2 r2) = do + as <- mapM (uncurry getSubst) ((r1, r2) : (v1, v2) : zip a1 a2) + return (foldl compose emptySubst as) + getSubst TAny _ = return emptySubst + getSubst t1 t2 + | t1 == t2 = return emptySubst + | otherwise = Left (SubsumptionError t1 t2) + +instance Subsuming Type where + TAny `subsume` TAny = Right TAny + t `subsume` TAny = Left (SubsumptionError t TAny) + TAny `subsume` _ = Right TAny + t1@(TNoArgFn at1) `subsume` (TNoArgFn at2) = do + _ <- at1 `subsume` at2 + return t1 + t1@(TFn at1 rt1) `subsume` (TFn at2 rt2) = do + _ <- at1 `subsume` at2 + _ <- rt1 `subsume` rt2 + return t1 + t1@(TUncurriedFn ats1 rt1) `subsume` (TUncurriedFn ats2 rt2) = do + mapM_ (uncurry subsume) (zip ats1 ats2) + _ <- rt1 `subsume` rt2 + return t1 + t1@(TVariadicFn ats1 vt1 rt1) `subsume` (TVariadicFn ats2 vt2 rt2) = do + mapM_ (uncurry subsume) (zip ats1 ats2) + _ <- vt1 `subsume` vt2 + _ <- rt1 `subsume` rt2 + return t1 + t1@(TSlice st1) `subsume` (TSlice st2) = do + _ <- st1 `subsume` st2 + return t1 + t1 `subsume` t2 + | t1 == t2 = Right t1 + | otherwise = Left (SubsumptionError t1 t2) + diff --git a/src/oden/Oden/Output/Infer.hs b/src/oden/Oden/Output/Infer.hs index c13a7b3..5084356 100644 --- a/src/oden/Oden/Output/Infer.hs +++ b/src/oden/Oden/Output/Infer.hs @@ -4,16 +4,18 @@ import Text.PrettyPrint import Oden.Output import Oden.Infer +import Oden.Infer.Subsumption instance OdenOutput TypeError where outputType _ = Error - name (UnificationFail _ _) = "Infer.UnificationFail" - name (InfiniteType _ _) = "Infer.InfiniteType" - name (NotInScope _) = "Infer.NotInScope" - name (Ambigious _) = "Infer.Ambigious" - name (UnificationMismatch _ _) = "Infer.UnificationMismatch" - name (ArgumentCountMismatch _ _) = "Infer.ArgumentCountMismatch" + name (UnificationFail _ _) = "Infer.UnificationFail" + name (InfiniteType _ _) = "Infer.InfiniteType" + name (NotInScope _) = "Infer.NotInScope" + name (Ambigious _) = "Infer.Ambigious" + name (UnificationMismatch _ _) = "Infer.UnificationMismatch" + name (ArgumentCountMismatch _ _) = "Infer.ArgumentCountMismatch" + name (TypeSignatureSubsumptionError _ (SubsumptionError _ _)) = "Infer.TypeSignatureSubsumptionError" header (UnificationFail t1 t2) s = text "Cannot unify types" <+> code s t1 <+> text "and" <+> code s t2 header (InfiniteType _ _) _ = text "Cannot construct an infinite type" @@ -24,6 +26,9 @@ instance OdenOutput TypeError where text "Function is applied to too few arguments" header (ArgumentCountMismatch _ _) _ = text "Function is applied to too many arguments" + header (TypeSignatureSubsumptionError n SubsumptionError{}) s = + text "Type signature for" <+> strCode s n + <+> text "does not subsume the type of the definition" details (UnificationFail _ _) _ = empty details (InfiniteType v t) s = code s v <+> equals <+> code s t @@ -36,4 +41,5 @@ instance OdenOutput TypeError where details (ArgumentCountMismatch as1 as2) s = text "Expected:" <+> vcat (map (code s) as1) $+$ text "Actual:" <+> vcat (map (code s) as2) - + details (TypeSignatureSubsumptionError _ (SubsumptionError t1 t2)) s = + text "Type" <+> code s t1 <+> text "does not subsume" <+> code s t2 diff --git a/test/Oden/CompilerSpec.hs b/test/Oden/CompilerSpec.hs index 1b3cd16..840188f 100644 --- a/test/Oden/CompilerSpec.hs +++ b/test/Oden/CompilerSpec.hs @@ -62,6 +62,7 @@ usingIdentityMonomorphed :: MonomorphedDefinition usingIdentityMonomorphed = MonomorphedDefinition "using-identity" + Mono.typeInt (Core.Application (Core.Symbol (Unqualified "identity_inst_int_to_int") (Mono.TFn Mono.typeInt Mono.typeInt)) (Core.Literal (Core.Int 1) Mono.typeInt) Mono.typeInt) @@ -81,6 +82,7 @@ usingIdentity2Monomorphed :: MonomorphedDefinition usingIdentity2Monomorphed = MonomorphedDefinition "using-identity2" + Mono.typeInt (Core.Application (Core.Symbol (Unqualified "identity2_inst_int_to_int") (Mono.TFn Mono.typeInt Mono.typeInt)) (Core.Literal (Core.Int 1) Mono.typeInt) Mono.typeInt) @@ -89,6 +91,7 @@ letBoundIdentityMonomorphed :: MonomorphedDefinition letBoundIdentityMonomorphed = MonomorphedDefinition "let-bound-identity" + Mono.typeInt (Core.Let "identity_inst_int_to_int" (Core.Fn "x" (Core.Symbol (Unqualified "x") Mono.typeInt) (Mono.TFn Mono.typeInt Mono.typeInt)) (Core.Application (Core.Symbol (Unqualified "identity_inst_int_to_int") (Mono.TFn Mono.typeInt Mono.typeInt)) (Core.Literal (Core.Int 1) Mono.typeInt) @@ -126,6 +129,7 @@ sliceLenMonomorphed :: MonomorphedDefinition sliceLenMonomorphed = MonomorphedDefinition "slice-len" + Mono.typeInt (Core.UncurriedFnApplication (Core.Symbol (Unqualified "len") (Mono.TUncurriedFn [Mono.TSlice Mono.typeBool] Mono.typeInt)) [Core.Slice [Core.Literal (Core.Bool True) Mono.typeBool] (Mono.TSlice Mono.typeBool)] @@ -150,6 +154,7 @@ letWithShadowingMonomorphed :: MonomorphedDefinition letWithShadowingMonomorphed = MonomorphedDefinition "let-with-shadowing" + Mono.typeInt (Core.Let "x" (Core.Literal (Core.Int 1) Mono.typeInt) diff --git a/test/Oden/Infer/SubsumptionSpec.hs b/test/Oden/Infer/SubsumptionSpec.hs new file mode 100644 index 0000000..ae2cae1 --- /dev/null +++ b/test/Oden/Infer/SubsumptionSpec.hs @@ -0,0 +1,53 @@ +module Oden.Infer.SubsumptionSpec where + +import Test.Hspec + +import Oden.Infer.Subsumption +import Oden.Type.Polymorphic + +import Oden.Assertions + + +tvarA :: Type +tvarA = TVar (TV "a") + +tvarB :: Type +tvarB = TVar (TV "b") + +spec :: Spec +spec = + describe "subsume" $ do + it "any subsume any" $ + TAny `subsume` TAny + `shouldSucceedWith` + TAny + it "any subsume int" $ + TAny `subsume` typeInt + `shouldSucceedWith` + TAny + it "int does not subsume any" $ + shouldFail (typeInt `subsume` TAny) + it "tvar does not subsume any" $ + shouldFail (tvarA `subsume` TAny) + it "any subsume tvar" $ + TAny `subsume` tvarA + `shouldSucceedWith` + TAny + it "tvar subsume same tvar" $ + tvarA `subsume` tvarA + `shouldSucceedWith` + tvarA + it "tvar does not subsume other tvars" $ + shouldFail (tvarA `subsume` tvarB) + it "tcon subsume same tcon" $ + typeInt `subsume` typeInt + `shouldSucceedWith` + typeInt + it "tcon does not subsume other tcons" $ + shouldFail (typeInt `subsume` typeBool) + it "TFn of TVars subsume same TFn" $ + TFn tvarA tvarA `subsume` TFn tvarA tvarA + `shouldSucceedWith` + TFn tvarA tvarA + it "TVar does not subsume TFn" $ + shouldFail (tvarA `subsume` TFn tvarB tvarB) diff --git a/test/Oden/InferSpec.hs b/test/Oden/InferSpec.hs index 11c7db0..eb8671a 100644 --- a/test/Oden/InferSpec.hs +++ b/test/Oden/InferSpec.hs @@ -33,6 +33,12 @@ predefAndIdentityAny = predef `extend` (Unqualified "identity", booleanOp :: Type booleanOp = typeBool `TFn` (typeBool `TFn` typeBool) +tvA :: TVar +tvA = TV "a" + +tvarA :: Type +tvarA = TVar (TV "a") + countToZero :: Untyped.Expr countToZero = Untyped.Fn @@ -87,6 +93,35 @@ countToZeroTyped = typeInt) intToInt) +twiceUntyped :: Untyped.Expr +twiceUntyped = + Untyped.Fn + "f" + (Untyped.Fn + "x" + (Untyped.Application + (Untyped.Symbol (Unqualified "f")) + [Untyped.Application + (Untyped.Symbol (Unqualified "f")) + [Untyped.Symbol (Unqualified "x")]])) + +twiceTyped :: Core.Definition +twiceTyped = + Core.Definition "twice" (Forall [tvA] (TFn (TFn tvarA tvarA) (TFn tvarA tvarA)), + Core.Fn + "f" + (Core.Fn + "x" + (Core.Application + (Core.Symbol (Unqualified "f") (TFn tvarA tvarA)) + (Core.Application + (Core.Symbol (Unqualified "f") (TFn tvarA tvarA)) + (Core.Symbol (Unqualified "x") tvarA) + tvarA) + tvarA) + (TFn tvarA tvarA)) + (TFn (TFn tvarA tvarA) (TFn tvarA tvarA))) + spec :: Spec spec = do describe "inferExpr" $ do @@ -298,6 +333,33 @@ spec = do Core.Definition "id" (Forall [TV "a"] (TFn (TVar (TV "a")) (TVar (TV "a"))), Core.Fn "x" (Core.Symbol (Unqualified "x") (TVar (TV "a"))) (TFn (TVar (TV "a")) (TVar (TV "a")))) + it "fails when specified type signature does not unify" $ + shouldFail $ + inferDefinition empty (Untyped.Definition "some-number" + (Just $ Forall [] typeBool) + (Untyped.Literal (Untyped.Int 1))) + + it "subsumes int with any" $ + inferDefinition empty (Untyped.Definition "some-number" + (Just $ Forall [] TAny) + (Untyped.Literal (Untyped.Int 1))) + `shouldSucceedWith` + Core.Definition "some-number" (Forall [] TAny, Core.Literal (Core.Int 1) typeInt) + + + it "infers twice function with correct type signature" $ + inferDefinition empty (Untyped.Definition "twice" + (Just $ Forall [tvA] (TFn (TFn tvarA tvarA) (TFn tvarA tvarA))) + twiceUntyped) + `shouldSucceedWith` + twiceTyped + + it "fails on twice function with incorrect type signature" $ + shouldFail $ + inferDefinition empty (Untyped.Definition "twice" + (Just $ Forall [tvA] (TFn tvarA tvarA)) + twiceUntyped) + it "infers recursive definition" $ inferDefinition predef (Untyped.Definition "f" (Just $ Forall [] intToInt) countToZero) `shouldSucceedWith` @@ -307,3 +369,7 @@ spec = do inferDefinition predef (Untyped.Definition "f" Nothing countToZero) `shouldSucceedWith` countToZeroTyped + + it "fails on recursive with incorrect signature" $ + shouldFail $ + inferDefinition predef (Untyped.Definition "f" (Just $ Forall [] (TFn typeInt TAny)) countToZero)