Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactorings #845

Merged
merged 3 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,18 @@ eliminateUserVariables = go
-- Mixed cases --
-----------------
-- In the next three cases, we can only fail to unblock these cases because
-- we can't evaluate networks applied to constant arguments.
-- we can't evaluate networks applied to constant arguments or because of if statements.
--
-- (if (forall x . f x > 0) then x else 0) > 0
--
-- When we have that ability then case can be turned to an error.
-- These cases can happen, e.g.
-- When we have the ability to evaluate networks then this case can be turned to a
-- call to purify..
INot {} -> compileUnquantifiedQuerySet expr
IEqual {} -> compileUnquantifiedQuerySet expr
INotEqual {} -> compileUnquantifiedQuerySet expr
IOrder {} -> compileUnquantifiedQuerySet expr
IVectorEqual {} -> compileUnquantifiedQuerySet expr
IVectorNotEqual {} -> compileUnquantifiedQuerySet expr
-- This final case can only occur at all because
-- we can't evaluate networks applied to constant arguments.
-- When we have that ability we can replace it with an error.
_ -> compileUnquantifiedQuerySet expr

compileQuantifiedQuerySet ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,6 @@ instance Pretty Assertion where
RationalIneq ineq -> pretty ineq
TensorEq eq -> pretty eq

checkTriviality :: Assertion -> MaybeTrivial Assertion
checkTriviality ass = case ass of
RationalEq RationalEquality {..} -> case isConstant rationalEqExpr of
Nothing -> NonTrivial ass
Just d -> Trivial (d == 0)
RationalIneq RationalInequality {..} -> case isConstant rationalIneqExpr of
Nothing -> NonTrivial ass
Just d -> Trivial ((if strictness == Strict then (<) else (<=)) d 0)
TensorEq TensorEquality {..} -> case isConstant tensorEqExpr of
Nothing -> NonTrivial ass
Just d -> Trivial (isZero d)

prettyAssertions :: [Assertion] -> Doc a
prettyAssertions assertions =
vsep (fmap pretty assertions)
Expand Down Expand Up @@ -350,6 +338,18 @@ mapAssertionExprs ft fr ass = checkTriviality $ case ass of
RationalEq RationalEquality {..} -> RationalEq $ RationalEquality $ fr rationalEqExpr
RationalIneq RationalInequality {..} -> RationalIneq $ RationalInequality strictness (fr rationalIneqExpr)

checkTriviality :: Assertion -> MaybeTrivial Assertion
checkTriviality ass = case ass of
RationalEq RationalEquality {..} -> case isConstant rationalEqExpr of
Nothing -> NonTrivial ass
Just d -> Trivial (d == 0)
RationalIneq RationalInequality {..} -> case isConstant rationalIneqExpr of
Nothing -> NonTrivial ass
Just d -> Trivial ((if strictness == Strict then (<) else (<=)) d 0)
TensorEq TensorEquality {..} -> case isConstant tensorEqExpr of
Nothing -> NonTrivial ass
Just d -> Trivial (isZero d)

substituteTensorEq ::
(OriginalUserVariable, LinearExpr TensorVariable RationalTensor) ->
Map RationalVariable (LinearExpr RationalVariable Rational) ->
Expand Down Expand Up @@ -514,7 +514,7 @@ lookupVarByLevel :: (MonadState GlobalCtx m) => Lv -> m Variable
lookupVarByLevel lv = do
GlobalCtx {..} <- get
case LinkedHashMap.lookup lv globalBoundVarCtx of
Nothing -> developerError "Cannout find variable var"
Nothing -> developerError "Cannot find variable var"
Just v -> return v

getReducedVariableExprFor :: (MonadState GlobalCtx m) => Lv -> m (Maybe (WHNFValue QueryBuiltin))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ solveTensorVariable userTensorVar solutions = \case
foldlM (solveExists fromRationalAssertion solveRationalVariable) initial userRationalVars
Inequalities {} ->
compilerDeveloperError $
"When trying to solve rational variable"
"When trying to solve tensor variable"
<+> quotePretty userTensorVar
<+> "found unexpected tensor inequalities."
<+> "found unexpected rational inequalities."

--------------------------------------------------------------------------------
-- UserRationalVariables and equalities/constraints
Expand Down
12 changes: 6 additions & 6 deletions vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,26 @@ import Vehicle.Data.Code.Value

liftIf ::
(Monad m) =>
(WHNFValue Builtin -> m (WHNFValue Builtin)) ->
WHNFValue Builtin ->
(WHNFValue Builtin -> m (WHNFValue Builtin)) ->
m (WHNFValue Builtin)
liftIf k (IIf t cond e1 e2) = IIf t cond <$> liftIf k e1 <*> liftIf k e2
liftIf k e = k e
liftIf (IIf t cond e1 e2) k = IIf t cond <$> liftIf e1 k <*> liftIf e2 k
liftIf e k = k e

liftIfArg ::
(Monad m) =>
(WHNFArg Builtin -> m (WHNFValue Builtin)) ->
WHNFArg Builtin ->
(WHNFArg Builtin -> m (WHNFValue Builtin)) ->
m (WHNFValue Builtin)
liftIfArg k (Arg p v r e) = liftIf (k . Arg p v r) e
liftIfArg (Arg p v r e) k = liftIf e (k . Arg p v r)

liftIfSpine ::
(Monad m) =>
WHNFSpine Builtin ->
(WHNFSpine Builtin -> m (WHNFValue Builtin)) ->
m (WHNFValue Builtin)
liftIfSpine [] k = k []
liftIfSpine (x : xs) k = liftIfArg (\a -> liftIfSpine xs (\as -> k (a : as))) x
liftIfSpine (x : xs) k = liftIfArg x (\a -> liftIfSpine xs (\as -> k (a : as)))

unfoldIf ::
(Monad m, MonadFreeContext Builtin m) =>
Expand Down
28 changes: 15 additions & 13 deletions vehicle/src/Vehicle/Compile/Boolean/Unblock.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ unblockNonVector actions expr = case expr of
IIf {} -> return expr
IForall {} -> return expr
IExists {} -> return expr
-- Can be removed?
IVectorEqualFull spine@(IVecEqSpine t _ _ _ _ _)
| isRatTensor (argExpr t) -> return expr
| otherwise -> appHiddenStdlibDef StdEqualsVector spine
-- Can be removed?
IVectorNotEqualFull spine@(IVecEqSpine t _ _ _ _ _)
| isRatTensor (argExpr t) -> return expr
| otherwise -> appHiddenStdlibDef StdNotEqualsVector spine
Expand Down Expand Up @@ -137,8 +139,8 @@ unblockNonVectorOp2 ::
unblockNonVectorOp2 actions b evalOp2 x y implArgs = do
x' <- unblockNonVector actions x
y' <- unblockNonVector actions y
flip liftIf x' $ \x'' ->
flip liftIf y' $ \y'' ->
liftIf x' $ \x'' ->
liftIf y' $ \y'' ->
forceEvalSimple b evalOp2 (implArgs <> [explicit x'', explicit y''])

unblockVectorOp2 ::
Expand All @@ -164,7 +166,7 @@ unblockFoldVector ::
m (WHNFValue Builtin)
unblockFoldVector actions t1 t2 n f e xs = do
xs' <- unblockVector actions True xs
flip liftIf xs' $ \xs'' ->
liftIf xs' $ \xs'' ->
forceEval FoldVector (evalFoldVector normaliseApp) [t1, t2, n, explicit f, explicit e, explicit xs'']

unblockMapVector ::
Expand All @@ -178,7 +180,7 @@ unblockMapVector ::
m (WHNFValue Builtin)
unblockMapVector actions t1 t2 n f xs = do
xs' <- unblockVector actions True xs
flip liftIf xs' $ \xs'' ->
liftIf xs' $ \xs'' ->
forceEval MapVector (evalMapVector normaliseApp) [t1, t2, n, explicit f, explicit xs'']

unblockZipWith ::
Expand All @@ -195,8 +197,8 @@ unblockZipWith ::
unblockZipWith actions t1 t2 t3 n f xs ys = do
xs' <- unblockVector actions True xs
ys' <- unblockVector actions True ys
flip liftIf xs' $ \xs'' ->
flip liftIf ys' $ \ys'' ->
liftIf xs' $ \xs'' ->
liftIf ys' $ \ys'' ->
forceEval ZipWithVector (evalZipWith normaliseApp) [t1, t2, t3, n, explicit f, explicit xs'', explicit ys'']

unblockAt ::
Expand All @@ -211,7 +213,7 @@ unblockAt ::
unblockAt actions t n c i = case c of
IVecLiteral {} -> do
i' <- unblockNonVector actions i
flip liftIf i' $ \i'' -> do
liftIf i' $ \i'' -> do
forceEvalSimple At evalAt [t, n, explicit c, explicit i'']
IMapVector _ _ t2 f xs -> appAt f [(t2, n, xs)] i
IZipWithVector t1 t2 _ _ f xs ys -> appAt f [(t1, n, xs), (t2, n, ys)] i
Expand Down Expand Up @@ -247,7 +249,7 @@ unblockIndices ::
m (WHNFValue Builtin)
unblockIndices actions n = do
n' <- unblockNonVector actions n
flip liftIf n' $ \n'' ->
liftIf n' $ \n'' ->
forceEvalSimple Indices (evalIndices (VBuiltinFunction Indices)) (explicit <$> [n''])

forceEval ::
Expand Down Expand Up @@ -366,8 +368,8 @@ purifyRatOp2 ::
purifyRatOp2 actions mkOp evalOp2 x y = do
x' <- purify actions x
y' <- purify actions y
flip liftIf x' $ \x'' ->
flip liftIf y' $ \y'' ->
liftIf x' $ \x'' ->
liftIf y' $ \y'' ->
return $ evalOp2 (mkOp x'' y'') [explicit x'', explicit y'']

purifyNegRat ::
Expand All @@ -377,7 +379,7 @@ purifyNegRat ::
m (WHNFValue Builtin)
purifyNegRat actions x = do
x' <- purify actions x
flip liftIf x' $ \x'' ->
liftIf x' $ \x'' ->
return $ evalNegRat (INeg NegRat x'') [explicit x'']

traverseVectorOp2 ::
Expand All @@ -391,8 +393,8 @@ traverseVectorOp2 ::
traverseVectorOp2 f fn spinePrefix xs ys = do
xs' <- f xs
ys' <- f ys
flip liftIf xs' $ \xs'' ->
flip liftIf ys' $ \ys'' -> do
liftIf xs' $ \xs'' ->
liftIf ys' $ \ys'' -> do
let newSpine = spinePrefix <> (Arg mempty Explicit Relevant <$> [xs'', ys''])
case (xs'', ys'') of
(IVecLiteral {}, IVecLiteral {}) -> appHiddenStdlibDef fn newSpine
Expand Down
4 changes: 1 addition & 3 deletions vehicle/src/Vehicle/Compile/Normalise/NBE.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ evalApp freeEnv fun args@(a : as) = do
visibilityError currentPass (prettyVerbose fun) (prettyVerbose args)
| otherwise -> do
body' <- evalClosure freeEnv closure (binder, argExpr a)
case as of
[] -> return body'
(b : bs) -> evalApp freeEnv body' (b : bs)
evalApp freeEnv body' as
VUniverse {} -> unexpectedExprError currentPass ("VUniverse" <+> prettyVerbose args)
VPi {} -> unexpectedExprError currentPass ("VPi" <+> prettyVerbose args)

Expand Down
Loading