From dafab18d10f23367a871903f969008efd4194431 Mon Sep 17 00:00:00 2001 From: MatthewDaggitt Date: Thu, 12 Sep 2024 13:07:37 +0800 Subject: [PATCH 1/3] Flipped order of if-lifting --- vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs | 12 ++++---- .../src/Vehicle/Compile/Boolean/Unblock.hs | 28 ++++++++++--------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs b/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs index 0f0409470..d080255e9 100644 --- a/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs +++ b/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs @@ -18,18 +18,18 @@ import Vehicle.Data.Code.Value liftIf :: (Monad m) => - (WHNFValue Builtin -> m (WHNFValue Builtin)) -> WHNFValue Builtin -> + (WHNFValue Builtin -> m (WHNFValue Builtin)) -> m (WHNFValue Builtin) -liftIf k (IIf t cond e1 e2) = IIf t cond <$> liftIf k e1 <*> liftIf k e2 -liftIf k e = k e +liftIf (IIf t cond e1 e2) k = IIf t cond <$> liftIf e1 k <*> liftIf e2 k +liftIf e k = k e liftIfArg :: (Monad m) => - (WHNFArg Builtin -> m (WHNFValue Builtin)) -> WHNFArg Builtin -> + (WHNFArg Builtin -> m (WHNFValue Builtin)) -> m (WHNFValue Builtin) -liftIfArg k (Arg p v r e) = liftIf (k . Arg p v r) e +liftIfArg (Arg p v r e) k = liftIf e (k . Arg p v r) liftIfSpine :: (Monad m) => @@ -37,7 +37,7 @@ liftIfSpine :: (WHNFSpine Builtin -> m (WHNFValue Builtin)) -> m (WHNFValue Builtin) liftIfSpine [] k = k [] -liftIfSpine (x : xs) k = liftIfArg (\a -> liftIfSpine xs (\as -> k (a : as))) x +liftIfSpine (x : xs) k = liftIfArg x (\a -> liftIfSpine xs (\as -> k (a : as))) unfoldIf :: (Monad m, MonadFreeContext Builtin m) => diff --git a/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs b/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs index 6d733fed3..e712e8520 100644 --- a/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs +++ b/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs @@ -80,9 +80,11 @@ unblockNonVector actions expr = case expr of IIf {} -> return expr IForall {} -> return expr IExists {} -> return expr + -- Can be removed? IVectorEqualFull spine@(IVecEqSpine t _ _ _ _ _) | isRatTensor (argExpr t) -> return expr | otherwise -> appHiddenStdlibDef StdEqualsVector spine + -- Can be removed? IVectorNotEqualFull spine@(IVecEqSpine t _ _ _ _ _) | isRatTensor (argExpr t) -> return expr | otherwise -> appHiddenStdlibDef StdNotEqualsVector spine @@ -137,8 +139,8 @@ unblockNonVectorOp2 :: unblockNonVectorOp2 actions b evalOp2 x y implArgs = do x' <- unblockNonVector actions x y' <- unblockNonVector actions y - flip liftIf x' $ \x'' -> - flip liftIf y' $ \y'' -> + liftIf x' $ \x'' -> + liftIf y' $ \y'' -> forceEvalSimple b evalOp2 (implArgs <> [explicit x'', explicit y'']) unblockVectorOp2 :: @@ -164,7 +166,7 @@ unblockFoldVector :: m (WHNFValue Builtin) unblockFoldVector actions t1 t2 n f e xs = do xs' <- unblockVector actions True xs - flip liftIf xs' $ \xs'' -> + liftIf xs' $ \xs'' -> forceEval FoldVector (evalFoldVector normaliseApp) [t1, t2, n, explicit f, explicit e, explicit xs''] unblockMapVector :: @@ -178,7 +180,7 @@ unblockMapVector :: m (WHNFValue Builtin) unblockMapVector actions t1 t2 n f xs = do xs' <- unblockVector actions True xs - flip liftIf xs' $ \xs'' -> + liftIf xs' $ \xs'' -> forceEval MapVector (evalMapVector normaliseApp) [t1, t2, n, explicit f, explicit xs''] unblockZipWith :: @@ -195,8 +197,8 @@ unblockZipWith :: unblockZipWith actions t1 t2 t3 n f xs ys = do xs' <- unblockVector actions True xs ys' <- unblockVector actions True ys - flip liftIf xs' $ \xs'' -> - flip liftIf ys' $ \ys'' -> + liftIf xs' $ \xs'' -> + liftIf ys' $ \ys'' -> forceEval ZipWithVector (evalZipWith normaliseApp) [t1, t2, t3, n, explicit f, explicit xs'', explicit ys''] unblockAt :: @@ -211,7 +213,7 @@ unblockAt :: unblockAt actions t n c i = case c of IVecLiteral {} -> do i' <- unblockNonVector actions i - flip liftIf i' $ \i'' -> do + liftIf i' $ \i'' -> do forceEvalSimple At evalAt [t, n, explicit c, explicit i''] IMapVector _ _ t2 f xs -> appAt f [(t2, n, xs)] i IZipWithVector t1 t2 _ _ f xs ys -> appAt f [(t1, n, xs), (t2, n, ys)] i @@ -247,7 +249,7 @@ unblockIndices :: m (WHNFValue Builtin) unblockIndices actions n = do n' <- unblockNonVector actions n - flip liftIf n' $ \n'' -> + liftIf n' $ \n'' -> forceEvalSimple Indices (evalIndices (VBuiltinFunction Indices)) (explicit <$> [n'']) forceEval :: @@ -366,8 +368,8 @@ purifyRatOp2 :: purifyRatOp2 actions mkOp evalOp2 x y = do x' <- purify actions x y' <- purify actions y - flip liftIf x' $ \x'' -> - flip liftIf y' $ \y'' -> + liftIf x' $ \x'' -> + liftIf y' $ \y'' -> return $ evalOp2 (mkOp x'' y'') [explicit x'', explicit y''] purifyNegRat :: @@ -377,7 +379,7 @@ purifyNegRat :: m (WHNFValue Builtin) purifyNegRat actions x = do x' <- purify actions x - flip liftIf x' $ \x'' -> + liftIf x' $ \x'' -> return $ evalNegRat (INeg NegRat x'') [explicit x''] traverseVectorOp2 :: @@ -391,8 +393,8 @@ traverseVectorOp2 :: traverseVectorOp2 f fn spinePrefix xs ys = do xs' <- f xs ys' <- f ys - flip liftIf xs' $ \xs'' -> - flip liftIf ys' $ \ys'' -> do + liftIf xs' $ \xs'' -> + liftIf ys' $ \ys'' -> do let newSpine = spinePrefix <> (Arg mempty Explicit Relevant <$> [xs'', ys'']) case (xs'', ys'') of (IVecLiteral {}, IVecLiteral {}) -> appHiddenStdlibDef fn newSpine From 6ad8065ac27eb30d61273d3bb0728cbc1ed63729 Mon Sep 17 00:00:00 2001 From: MatthewDaggitt Date: Thu, 12 Sep 2024 13:15:21 +0800 Subject: [PATCH 2/3] Simplified evalApp --- vehicle/src/Vehicle/Compile/Normalise/NBE.hs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vehicle/src/Vehicle/Compile/Normalise/NBE.hs b/vehicle/src/Vehicle/Compile/Normalise/NBE.hs index f6933fe59..247223d04 100644 --- a/vehicle/src/Vehicle/Compile/Normalise/NBE.hs +++ b/vehicle/src/Vehicle/Compile/Normalise/NBE.hs @@ -165,9 +165,7 @@ evalApp freeEnv fun args@(a : as) = do visibilityError currentPass (prettyVerbose fun) (prettyVerbose args) | otherwise -> do body' <- evalClosure freeEnv closure (binder, argExpr a) - case as of - [] -> return body' - (b : bs) -> evalApp freeEnv body' (b : bs) + evalApp freeEnv body' as VUniverse {} -> unexpectedExprError currentPass ("VUniverse" <+> prettyVerbose args) VPi {} -> unexpectedExprError currentPass ("VPi" <+> prettyVerbose args) From 349c855f1d5a47f1931e22fbd5c93d9b308e3020 Mon Sep 17 00:00:00 2001 From: MatthewDaggitt Date: Wed, 25 Sep 2024 12:53:47 +0800 Subject: [PATCH 3/3] Fix minor typos --- .../Queries/UserVariableElimination.hs | 9 +++---- .../Queries/UserVariableElimination/Core.hs | 26 +++++++++---------- .../EliminateExists.hs | 4 +-- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs index ece6b8d20..a7dd9d002 100644 --- a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs +++ b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs @@ -73,21 +73,18 @@ eliminateUserVariables = go -- Mixed cases -- ----------------- -- In the next three cases, we can only fail to unblock these cases because - -- we can't evaluate networks applied to constant arguments. + -- we can't evaluate networks applied to constant arguments or because of if statements. -- -- (if (forall x . f x > 0) then x else 0) > 0 -- - -- When we have that ability then case can be turned to an error. - -- These cases can happen, e.g. + -- When we have the ability to evaluate networks then this case can be turned to a + -- call to purify.. INot {} -> compileUnquantifiedQuerySet expr IEqual {} -> compileUnquantifiedQuerySet expr INotEqual {} -> compileUnquantifiedQuerySet expr IOrder {} -> compileUnquantifiedQuerySet expr IVectorEqual {} -> compileUnquantifiedQuerySet expr IVectorNotEqual {} -> compileUnquantifiedQuerySet expr - -- This final case can only occur at all because - -- we can't evaluate networks applied to constant arguments. - -- When we have that ability we can replace it with an error. _ -> compileUnquantifiedQuerySet expr compileQuantifiedQuerySet :: diff --git a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs index 81a472532..78eb5b1d4 100644 --- a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs +++ b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs @@ -278,18 +278,6 @@ instance Pretty Assertion where RationalIneq ineq -> pretty ineq TensorEq eq -> pretty eq -checkTriviality :: Assertion -> MaybeTrivial Assertion -checkTriviality ass = case ass of - RationalEq RationalEquality {..} -> case isConstant rationalEqExpr of - Nothing -> NonTrivial ass - Just d -> Trivial (d == 0) - RationalIneq RationalInequality {..} -> case isConstant rationalIneqExpr of - Nothing -> NonTrivial ass - Just d -> Trivial ((if strictness == Strict then (<) else (<=)) d 0) - TensorEq TensorEquality {..} -> case isConstant tensorEqExpr of - Nothing -> NonTrivial ass - Just d -> Trivial (isZero d) - prettyAssertions :: [Assertion] -> Doc a prettyAssertions assertions = vsep (fmap pretty assertions) @@ -350,6 +338,18 @@ mapAssertionExprs ft fr ass = checkTriviality $ case ass of RationalEq RationalEquality {..} -> RationalEq $ RationalEquality $ fr rationalEqExpr RationalIneq RationalInequality {..} -> RationalIneq $ RationalInequality strictness (fr rationalIneqExpr) +checkTriviality :: Assertion -> MaybeTrivial Assertion +checkTriviality ass = case ass of + RationalEq RationalEquality {..} -> case isConstant rationalEqExpr of + Nothing -> NonTrivial ass + Just d -> Trivial (d == 0) + RationalIneq RationalInequality {..} -> case isConstant rationalIneqExpr of + Nothing -> NonTrivial ass + Just d -> Trivial ((if strictness == Strict then (<) else (<=)) d 0) + TensorEq TensorEquality {..} -> case isConstant tensorEqExpr of + Nothing -> NonTrivial ass + Just d -> Trivial (isZero d) + substituteTensorEq :: (OriginalUserVariable, LinearExpr TensorVariable RationalTensor) -> Map RationalVariable (LinearExpr RationalVariable Rational) -> @@ -514,7 +514,7 @@ lookupVarByLevel :: (MonadState GlobalCtx m) => Lv -> m Variable lookupVarByLevel lv = do GlobalCtx {..} <- get case LinkedHashMap.lookup lv globalBoundVarCtx of - Nothing -> developerError "Cannout find variable var" + Nothing -> developerError "Cannot find variable var" Just v -> return v getReducedVariableExprFor :: (MonadState GlobalCtx m) => Lv -> m (Maybe (WHNFValue QueryBuiltin)) diff --git a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/EliminateExists.hs b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/EliminateExists.hs index 2e94cf9ab..2f2044855 100644 --- a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/EliminateExists.hs +++ b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/EliminateExists.hs @@ -96,9 +96,9 @@ solveTensorVariable userTensorVar solutions = \case foldlM (solveExists fromRationalAssertion solveRationalVariable) initial userRationalVars Inequalities {} -> compilerDeveloperError $ - "When trying to solve rational variable" + "When trying to solve tensor variable" <+> quotePretty userTensorVar - <+> "found unexpected tensor inequalities." + <+> "found unexpected rational inequalities." -------------------------------------------------------------------------------- -- UserRationalVariables and equalities/constraints