From 25c73ee2560a2baf1204139cd7d9a6d9386fbedb Mon Sep 17 00:00:00 2001 From: vox9 <139348551+vox9@users.noreply.github.com> Date: Mon, 14 Oct 2024 13:08:29 +0200 Subject: [PATCH] Implement forward- and reverse mode AD in the interpreter (#2186) Co-authored-by: Troels Henriksen --- .github/workflows/main.yml | 4 +- CHANGELOG.md | 2 + futhark.cabal | 1 + src/Language/Futhark/Interpreter.hs | 317 +++++++++++++++----- src/Language/Futhark/Interpreter/AD.hs | 320 +++++++++++++++++++++ src/Language/Futhark/Interpreter/Values.hs | 45 ++- tests/ad/arr0.fut | 4 +- tests/ad/arr1.fut | 2 +- tests/ad/arr2.fut | 2 +- tests/ad/concat0.fut | 4 +- tests/ad/confusion0.fut | 2 +- tests/ad/consume0.fut | 2 +- tests/ad/consume1.fut | 2 +- tests/ad/consume2.fut | 2 +- tests/ad/consume3.fut | 4 +- tests/ad/consume4.fut | 6 +- tests/ad/consume5.fut | 2 +- tests/ad/consume6.fut | 4 +- tests/ad/conv0.fut | 4 +- tests/ad/conv1.fut | 4 +- tests/ad/fadd.fut | 2 +- tests/ad/fdiv.fut | 2 +- tests/ad/fmul.fut | 2 +- tests/ad/for0.fut | 8 +- tests/ad/for1.fut | 4 +- tests/ad/for2.fut | 4 +- tests/ad/for3.fut | 4 +- tests/ad/fwd/acc0.fut | 4 +- tests/ad/fwd/for0.fut | 8 +- tests/ad/fwd/for1.fut | 8 +- tests/ad/fwd/map0.fut | 4 +- tests/ad/fwd/red0.fut | 4 +- tests/ad/fwd/scatter0.fut | 2 +- tests/ad/fwd/while0.fut | 8 +- tests/ad/gather0.fut | 4 +- tests/ad/gather1.fut | 2 +- tests/ad/gather2.fut | 2 +- tests/ad/genred-opt/matmul.fut | 2 +- tests/ad/if0.fut | 8 +- tests/ad/if1.fut | 8 +- tests/ad/if2.fut | 8 +- tests/ad/imul.fut | 2 +- tests/ad/issue1577.fut | 2 +- tests/ad/issue1604.fut | 2 +- tests/ad/issue1879.fut | 4 +- tests/ad/lighthouse.fut | 2 +- tests/ad/map0.fut | 2 +- tests/ad/map1.fut | 2 +- tests/ad/map2.fut | 2 +- tests/ad/map3.fut | 2 +- tests/ad/map4.fut | 2 +- tests/ad/map5.fut | 2 +- tests/ad/map6.fut | 2 +- tests/ad/map7.fut | 2 +- tests/ad/matmul.fut | 2 +- tests/ad/maximum.fut | 4 +- tests/ad/minimum.fut | 4 +- tests/ad/minmax.fut | 2 +- tests/ad/negate.fut | 8 + 
tests/ad/nested0.fut | 2 +- tests/ad/nested1.fut | 2 +- tests/ad/nested2.fut | 2 +- tests/ad/nested3.fut | 2 +- tests/ad/nested4.fut | 2 +- tests/ad/not.fut | 8 + tests/ad/rearrange0.fut | 4 +- tests/ad/reduce0.fut | 2 +- tests/ad/reduce1.fut | 2 +- tests/ad/reduce2.fut | 2 +- tests/ad/reduce_by_index0.fut | 2 +- tests/ad/reducebyindex0.fut | 2 +- tests/ad/reducebyindex2.fut | 2 +- tests/ad/reducebyindex3.fut | 2 +- tests/ad/reducebyindex4.fut | 2 +- tests/ad/reducebyindex6.fut | 2 +- tests/ad/reducebyindexadd0.fut | 2 +- tests/ad/reducebyindexadd1.fut | 2 +- tests/ad/reducebyindexadd2.fut | 4 +- tests/ad/reducebyindexadd3.fut | 2 +- tests/ad/reducebyindexadd4.fut | 2 +- tests/ad/reducebyindexminmax0.fut | 10 +- tests/ad/reducebyindexminmax1.fut | 12 +- tests/ad/reducebyindexminmax10.fut | 2 +- tests/ad/reducebyindexminmax2.fut | 6 +- tests/ad/reducebyindexminmax3.fut | 6 +- tests/ad/reducebyindexminmax4.fut | 6 +- tests/ad/reducebyindexminmax5.fut | 2 +- tests/ad/reducebyindexminmax6.fut | 4 +- tests/ad/reducebyindexminmax9.fut | 2 +- tests/ad/reducebyindexmul0.fut | 2 +- tests/ad/reducebyindexmul1.fut | 4 +- tests/ad/reducebyindexmul2.fut | 8 +- tests/ad/reducebyindexmul3.fut | 2 +- tests/ad/reducebyindexmul4.fut | 2 +- tests/ad/reducebyindexvecmin0.fut | 2 +- tests/ad/reducebyindexvecmul0.fut | 2 +- tests/ad/reducemul0.fut | 2 +- tests/ad/reducemul1.fut | 2 +- tests/ad/reducemul2.fut | 2 +- tests/ad/reducemul3.fut | 2 +- tests/ad/reducemul4.fut | 2 +- tests/ad/reducevec0.fut | 2 +- tests/ad/reducevecmul0.fut | 2 +- tests/ad/reducevecmul1.fut | 2 +- tests/ad/reducevecmul2.fut | 2 +- tests/ad/reducevecmul3.fut | 2 +- tests/ad/replicate0.fut | 4 +- tests/ad/replicate1.fut | 4 +- tests/ad/replicate2.fut | 4 +- tests/ad/reshape0.fut | 4 +- tests/ad/rev_const.fut | 2 +- tests/ad/rev_unused.fut | 2 +- tests/ad/rotate0.fut | 4 +- tests/ad/scan0.fut | 2 +- tests/ad/scan1.fut | 2 +- tests/ad/scan2.fut | 2 +- tests/ad/scan3.fut | 2 +- tests/ad/scan4.fut | 2 +- 
tests/ad/scan5.fut | 2 +- tests/ad/scan6.fut | 4 +- tests/ad/scan7.fut | 4 +- tests/ad/scan8.fut | 2 +- tests/ad/scan9.fut | 2 +- tests/ad/scatter0.fut | 2 +- tests/ad/scatter1.fut | 2 +- tests/ad/sdf.fut | 6 +- tests/ad/stripmine0.fut | 6 +- tests/ad/stripmine1.fut | 4 +- tests/ad/stripmine2.fut | 4 +- tests/ad/stripmine3.fut | 4 +- tests/ad/sum.fut | 2 +- tests/ad/truedep0.fut | 4 +- tests/ad/while0.fut | 8 +- tests/ad/while1.fut | 8 +- 134 files changed, 836 insertions(+), 285 deletions(-) create mode 100644 src/Language/Futhark/Interpreter/AD.hs create mode 100644 tests/ad/negate.fut create mode 100644 tests/ad/not.fut diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9f1046f5b7..160ac49214 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -362,7 +362,7 @@ jobs: make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | - futhark test tests -c --no-terminal --backend=opencl --exclude=compiled --cache-extension=cache --pass-option=--build-option=-O0 --runner=tools/oclgrindrunner.sh + futhark test tests -c --no-terminal --backend=opencl --exclude=compiled --exclude=no_oclgrind --cache-extension=cache --pass-option=--build-option=-O0 --runner=tools/oclgrindrunner.sh test-pyoclgrind: runs-on: ubuntu-22.04 @@ -386,7 +386,7 @@ jobs: python -m venv virtualenv source virtualenv/bin/activate pip install 'numpy<2.0.0' pyopencl jsonschema - futhark test tests -c --no-terminal --backend=pyopencl --exclude=compiled --cache-extension=cache --pass-option=--build-option=-O0 --runner=tools/oclgrindrunner.sh + futhark test tests -c --no-terminal --backend=pyopencl --exclude=compiled --exclude=no_oclgrind --cache-extension=cache --pass-option=--build-option=-O0 --runner=tools/oclgrindrunner.sh test-opencl: runs-on: hendrix diff --git a/CHANGELOG.md b/CHANGELOG.md index 2269a88a2b..a1a8a1d1fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project 
adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Faster floating-point atomics with OpenCL backend on AMD and NVIDIA GPUs. This affects histogram workloads. +* AD is now supported by the interpreter (thanks to Marcus Jensen). + ### Removed ### Changed diff --git a/futhark.cabal b/futhark.cabal index 6191a6bcbd..47de94224f 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -397,6 +397,7 @@ library Language.Futhark Language.Futhark.Core Language.Futhark.Interpreter + Language.Futhark.Interpreter.AD Language.Futhark.Interpreter.Values Language.Futhark.FreeVars Language.Futhark.Parser diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 9c3b3bacfb..ea2bf21524 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -56,6 +56,7 @@ import Futhark.Util.Loc import Futhark.Util.Pretty hiding (apply) import Language.Futhark hiding (Shape, matchDims) import Language.Futhark qualified as F +import Language.Futhark.Interpreter.AD qualified as AD import Language.Futhark.Interpreter.Values hiding (Value) import Language.Futhark.Interpreter.Values qualified import Language.Futhark.Primitive (floatValue, intValue) @@ -263,6 +264,9 @@ asInteger :: Value -> Integer asInteger (ValuePrim (SignedValue v)) = P.valueIntegral v asInteger (ValuePrim (UnsignedValue v)) = toInteger (P.valueIntegral (P.doZExt v Int64) :: Word64) +asInteger (ValueAD d v) + | P.IntValue v' <- AD.primitive $ AD.primal $ AD.Variable d v = + P.valueIntegral v' asInteger v = error $ "Unexpectedly not an integer: " <> show v asInt :: Value -> Int @@ -270,13 +274,17 @@ asInt = fromIntegral . 
asInteger asSigned :: Value -> IntValue asSigned (ValuePrim (SignedValue v)) = v -asSigned v = error $ "Unexpected not a signed integer: " <> show v +asSigned (ValueAD d v) + | P.IntValue v' <- AD.primitive $ AD.primal $ AD.Variable d v = v' +asSigned v = error $ "Unexpectedly not a signed integer: " <> show v asInt64 :: Value -> Int64 asInt64 = fromIntegral . asInteger asBool :: Value -> Bool asBool (ValuePrim (BoolValue x)) = x +asBool (ValueAD d v) + | P.BoolValue v' <- AD.primitive $ AD.primal $ AD.Variable d v = v' asBool v = error $ "Unexpectedly not a boolean: " <> show v lookupInEnv :: @@ -937,6 +945,12 @@ evalAppExp env (Match e cs _) = do Just v' -> pure v' Nothing -> match v cs' +zeroOfType :: PrimType -> Value +zeroOfType (Signed it) = ValuePrim $ SignedValue $ P.intValue it (0 :: Int) +zeroOfType (Unsigned it) = ValuePrim $ UnsignedValue $ P.intValue it (0 :: Int) +zeroOfType (FloatType ft) = ValuePrim $ FloatValue $ P.floatValue ft (0 :: Int) +zeroOfType Bool = ValuePrim $ BoolValue False + eval :: Env -> Exp -> EvalM Value eval _ (Literal v _) = pure $ ValuePrim v eval env (Hole (Info t) loc) = @@ -1008,28 +1022,15 @@ eval _ (FloatLit v (Info t) _) = Scalar (Prim (FloatType ft)) -> pure $ ValuePrim $ FloatValue $ floatValue ft v _ -> error $ "eval: nonsensical type for float literal: " <> prettyString t -eval env (Negate e _) = do - ev <- eval env e - ValuePrim <$> case ev of - ValuePrim (SignedValue (Int8Value v)) -> pure $ SignedValue $ Int8Value (-v) - ValuePrim (SignedValue (Int16Value v)) -> pure $ SignedValue $ Int16Value (-v) - ValuePrim (SignedValue (Int32Value v)) -> pure $ SignedValue $ Int32Value (-v) - ValuePrim (SignedValue (Int64Value v)) -> pure $ SignedValue $ Int64Value (-v) - ValuePrim (UnsignedValue (Int8Value v)) -> pure $ UnsignedValue $ Int8Value (-v) - ValuePrim (UnsignedValue (Int16Value v)) -> pure $ UnsignedValue $ Int16Value (-v) - ValuePrim (UnsignedValue (Int32Value v)) -> pure $ UnsignedValue $ Int32Value (-v) - 
ValuePrim (UnsignedValue (Int64Value v)) -> pure $ UnsignedValue $ Int64Value (-v) - ValuePrim (FloatValue (Float16Value v)) -> pure $ FloatValue $ Float16Value (-v) - ValuePrim (FloatValue (Float32Value v)) -> pure $ FloatValue $ Float32Value (-v) - ValuePrim (FloatValue (Float64Value v)) -> pure $ FloatValue $ Float64Value (-v) - _ -> error $ "Cannot negate " <> show ev -eval env (Not e _) = do - ev <- eval env e - ValuePrim <$> case ev of - ValuePrim (BoolValue b) -> pure $ BoolValue $ not b - ValuePrim (SignedValue iv) -> pure $ SignedValue $ P.doComplement iv - ValuePrim (UnsignedValue iv) -> pure $ UnsignedValue $ P.doComplement iv - _ -> error $ "Cannot logically negate " <> show ev +eval env (Negate e loc) = + -- -x = 0-x + case typeOf e of + Scalar (Prim pt) -> do + ev <- eval env e + apply2 loc env intrinsicsMinus (zeroOfType pt) ev + t -> error $ "Cannot negate expression of type " <> prettyString t +eval env (Not e loc) = + apply loc env intrinsicsNot =<< eval env e eval env (Update src is v loc) = maybe oob pure =<< writeArray <$> mapM (evalDimIndex env) is <*> eval env src <*> eval env v @@ -1276,44 +1277,44 @@ initialCtx = types = M.mapMaybeWithKey (const . tdef . 
baseString) intrinsics sintOp f = - [ (getS, putS, P.doBinOp (f Int8)), - (getS, putS, P.doBinOp (f Int16)), - (getS, putS, P.doBinOp (f Int32)), - (getS, putS, P.doBinOp (f Int64)) + [ (getS, putS, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)), + (getS, putS, P.doBinOp (f Int16), adBinOp $ AD.OpBin (f Int16)), + (getS, putS, P.doBinOp (f Int32), adBinOp $ AD.OpBin (f Int32)), + (getS, putS, P.doBinOp (f Int64), adBinOp $ AD.OpBin (f Int64)) ] uintOp f = - [ (getU, putU, P.doBinOp (f Int8)), - (getU, putU, P.doBinOp (f Int16)), - (getU, putU, P.doBinOp (f Int32)), - (getU, putU, P.doBinOp (f Int64)) + [ (getU, putU, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)), + (getU, putU, P.doBinOp (f Int16), adBinOp $ AD.OpBin (f Int16)), + (getU, putU, P.doBinOp (f Int32), adBinOp $ AD.OpBin (f Int32)), + (getU, putU, P.doBinOp (f Int64), adBinOp $ AD.OpBin (f Int64)) ] intOp f = sintOp f ++ uintOp f floatOp f = - [ (getF, putF, P.doBinOp (f Float16)), - (getF, putF, P.doBinOp (f Float32)), - (getF, putF, P.doBinOp (f Float64)) + [ (getF, putF, P.doBinOp (f Float16), adBinOp $ AD.OpBin (f Float16)), + (getF, putF, P.doBinOp (f Float32), adBinOp $ AD.OpBin (f Float32)), + (getF, putF, P.doBinOp (f Float64), adBinOp $ AD.OpBin (f Float64)) ] arithOp f g = Just $ bopDef $ intOp f ++ floatOp g - flipCmps = map (\(f, g, h) -> (f, g, flip h)) + flipCmps = map (\(f, g, h, o) -> (f, g, flip h, flip o)) sintCmp f = - [ (getS, Just . BoolValue, P.doCmpOp (f Int8)), - (getS, Just . BoolValue, P.doCmpOp (f Int16)), - (getS, Just . BoolValue, P.doCmpOp (f Int32)), - (getS, Just . BoolValue, P.doCmpOp (f Int64)) + [ (getS, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)), + (getS, Just . BoolValue, P.doCmpOp (f Int16), adBinOp $ AD.OpCmp (f Int16)), + (getS, Just . BoolValue, P.doCmpOp (f Int32), adBinOp $ AD.OpCmp (f Int32)), + (getS, Just . BoolValue, P.doCmpOp (f Int64), adBinOp $ AD.OpCmp (f Int64)) ] uintCmp f = - [ (getU, Just . 
BoolValue, P.doCmpOp (f Int8)), - (getU, Just . BoolValue, P.doCmpOp (f Int16)), - (getU, Just . BoolValue, P.doCmpOp (f Int32)), - (getU, Just . BoolValue, P.doCmpOp (f Int64)) + [ (getU, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)), + (getU, Just . BoolValue, P.doCmpOp (f Int16), adBinOp $ AD.OpCmp (f Int16)), + (getU, Just . BoolValue, P.doCmpOp (f Int32), adBinOp $ AD.OpCmp (f Int32)), + (getU, Just . BoolValue, P.doCmpOp (f Int64), adBinOp $ AD.OpCmp (f Int64)) ] floatCmp f = - [ (getF, Just . BoolValue, P.doCmpOp (f Float16)), - (getF, Just . BoolValue, P.doCmpOp (f Float32)), - (getF, Just . BoolValue, P.doCmpOp (f Float64)) + [ (getF, Just . BoolValue, P.doCmpOp (f Float16), adBinOp $ AD.OpCmp (f Float16)), + (getF, Just . BoolValue, P.doCmpOp (f Float32), adBinOp $ AD.OpCmp (f Float32)), + (getF, Just . BoolValue, P.doCmpOp (f Float64), adBinOp $ AD.OpCmp (f Float64)) ] - boolCmp f = [(getB, Just . BoolValue, P.doCmpOp f)] + boolCmp f = [(getB, Just . BoolValue, P.doCmpOp f, adBinOp $ AD.OpCmp f)] getV (SignedValue x) = Just $ P.IntValue x getV (UnsignedValue x) = Just $ P.IntValue x @@ -1344,6 +1345,17 @@ initialCtx = putB (P.BoolValue x) = Just $ BoolValue x putB _ = Nothing + getAD (ValuePrim v) = AD.Constant <$> getV v + getAD (ValueAD d v) = Just $ AD.Variable d v + getAD _ = Nothing + putAD (AD.Variable d s) = ValueAD d s + putAD (AD.Constant v) = ValuePrim $ putV v + + adToPrim v = putV $ AD.primitive v + + adBinOp op x y = AD.doOp op [x, y] + adUnOp op x = AD.doOp op [x] + fun1 f = TermValue Nothing $ ValueFun $ \x -> f x @@ -1408,6 +1420,12 @@ initialCtx = | Just z <- msum $ map (`bopDef'` (x', y')) fs -> do breakOnNaN [x', y'] z pure $ ValuePrim z + _ + | Just x' <- getAD x, + Just y' <- getAD y, + Just z <- msum $ map (`bopDefAD'` (x', y')) fs -> do + breakOnNaN [adToPrim x', adToPrim y'] $ adToPrim z + pure $ putAD z _ -> bad noLoc mempty . 
docText $ "Cannot apply operator to arguments" @@ -1416,10 +1434,11 @@ initialCtx = <+> dquotes (prettyValue y) <> "." where - bopDef' (valf, retf, op) (x, y) = do + bopDef' (valf, retf, op, _) (x, y) = do x' <- valf x y' <- valf y retf =<< op x' y' + bopDefAD' (_, _, _, dop) (x, y) = dop x y unopDef fs = fun1 $ \x -> case x of @@ -1427,17 +1446,23 @@ initialCtx = | Just r <- msum $ map (`unopDef'` x') fs -> do breakOnNaN [x'] r pure $ ValuePrim r + _ + | Just x' <- getAD x, + Just r <- msum $ map (`unopDefAD'` x') fs -> do + breakOnNaN [adToPrim x'] $ adToPrim r + pure $ putAD r _ -> bad noLoc mempty . docText $ "Cannot apply function to argument" <+> dquotes (prettyValue x) <> "." where - unopDef' (valf, retf, op) x = do + unopDef' (valf, retf, op, _) x = do x' <- valf x retf =<< op x' + unopDefAD' (_, _, _, dop) = dop - tbopDef f = fun1 $ \v -> + tbopDef op f = fun1 $ \v -> case fromTuple v of Just [ValuePrim x, ValuePrim y] | Just x' <- getV x, @@ -1445,6 +1470,12 @@ initialCtx = Just z <- putV <$> f x' y' -> do breakOnNaN [x, y] z pure $ ValuePrim z + Just [x, y] + | Just x' <- getAD x, + Just y' <- getAD y, + Just z <- AD.doOp op [x', y'] -> do + breakOnNaN [adToPrim x', adToPrim y'] $ adToPrim z + pure $ putAD z _ -> bad noLoc mempty . docText $ "Cannot apply operator to argument" @@ -1454,15 +1485,15 @@ initialCtx = def "!" 
= Just $ unopDef - [ (getS, putS, P.doUnOp $ P.Complement Int8), - (getS, putS, P.doUnOp $ P.Complement Int16), - (getS, putS, P.doUnOp $ P.Complement Int32), - (getS, putS, P.doUnOp $ P.Complement Int64), - (getU, putU, P.doUnOp $ P.Complement Int8), - (getU, putU, P.doUnOp $ P.Complement Int16), - (getU, putU, P.doUnOp $ P.Complement Int32), - (getU, putU, P.doUnOp $ P.Complement Int64), - (getB, putB, P.doUnOp P.Not) + [ (getS, putS, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8), + (getS, putS, P.doUnOp $ P.Complement Int16, adUnOp $ AD.OpUn $ P.Complement Int16), + (getS, putS, P.doUnOp $ P.Complement Int32, adUnOp $ AD.OpUn $ P.Complement Int32), + (getS, putS, P.doUnOp $ P.Complement Int64, adUnOp $ AD.OpUn $ P.Complement Int64), + (getU, putU, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8), + (getU, putU, P.doUnOp $ P.Complement Int16, adUnOp $ AD.OpUn $ P.Complement Int16), + (getU, putU, P.doUnOp $ P.Complement Int32, adUnOp $ AD.OpUn $ P.Complement Int32), + (getU, putU, P.doUnOp $ P.Complement Int64, adUnOp $ AD.OpUn $ P.Complement Int64), + (getB, putB, P.doUnOp P.Not, adUnOp $ AD.OpUn P.Not) ] def "+" = arithOp (`P.Add` P.OverflowWrap) P.FAdd def "-" = arithOp (`P.Sub` P.OverflowWrap) P.FSub @@ -1542,16 +1573,16 @@ initialCtx = ++ boolCmp P.CmpLle def s | Just bop <- find ((s ==) . prettyString) P.allBinOps = - Just $ tbopDef $ P.doBinOp bop + Just $ tbopDef (AD.OpBin bop) $ P.doBinOp bop | Just unop <- find ((s ==) . prettyString) P.allCmpOps = - Just $ tbopDef $ \x y -> P.BoolValue <$> P.doCmpOp unop x y + Just $ tbopDef (AD.OpCmp unop) $ \x y -> P.BoolValue <$> P.doCmpOp unop x y | Just cop <- find ((s ==) . prettyString) P.allConvOps = - Just $ unopDef [(getV, Just . putV, P.doConvOp cop)] + Just $ unopDef [(getV, Just . putV, P.doConvOp cop, adUnOp $ AD.OpConv cop)] | Just unop <- find ((s ==) . prettyString) P.allUnOps = - Just $ unopDef [(getV, Just . 
putV, P.doUnOp unop)] + Just $ unopDef [(getV, Just . putV, P.doUnOp unop, adUnOp $ AD.OpUn unop)] | Just (pts, _, f) <- M.lookup s P.primFuns = case length pts of - 1 -> Just $ unopDef [(getV, Just . putV, f . pure)] + 1 -> Just $ unopDef [(getV, Just . putV, f . pure, adUnOp $ AD.OpFn s)] _ -> Just $ fun1 $ \x -> do let getV' (ValuePrim v) = Just v @@ -1742,14 +1773,22 @@ initialCtx = ( ValueAcc shape op acc_arr, ValuePrim (SignedValue (Int64Value i')) ) -> - if i' >= 0 && i' < arrayLength acc_arr - then do - let x = acc_arr ! fromIntegral i' - res <- op x v - pure $ ValueAcc shape op $ acc_arr // [(fromIntegral i', res)] - else pure acc + write acc v shape op acc_arr i' + ( ValueAcc shape op acc_arr, + adv@(ValueAD {}) + ) + | Just (SignedValue (Int64Value i')) <- putV . AD.primitive <$> getAD adv -> + write acc v shape op acc_arr i' _ -> error $ "acc_write invalid arguments: " <> prettyString (show acc, show i, show v) + where + write acc v shape op acc_arr i' = + if i' >= 0 && i' < arrayLength acc_arr + then do + let x = acc_arr ! fromIntegral i' + res <- op x v + pure $ ValueAcc shape op $ acc_arr // [(fromIntegral i', res)] + else pure acc -- def "flat_index_2d" = Just . fun6 $ \arr offset n1 s1 n2 s2 -> do let offset' = asInt64 offset @@ -1926,11 +1965,129 @@ initialCtx = else pure $ toArray shape $ map (toArray rowshape) $ chunk (asInt m) xs' def "manifest" = Just $ fun1 pure def "vjp2" = Just $ - fun3 $ - \_ _ _ -> bad noLoc mempty "Interpreter does not support autodiff." + -- TODO: This could be much better. Currently, it is very inefficient + -- Perhaps creating VJPValues could be abstracted into a function + -- exposed by the AD module? + fun3 $ \f v s -> do + -- Get the depth + depth <- length <$> stacktrace + + -- Augment the values + let v' = + fromMaybe (error $ "vjp: invalid values " ++ show v) $ + modifyValueM (\i lv -> ValueAD depth . AD.VJP . AD.VJPValue . 
AD.TapeID i <$> getAD lv) v + -- Turn the seeds into a list of ADValues + let s' = + fromMaybe (error $ "vjp: invalid seeds " ++ show s) $ + mapM getAD $ + fst $ + valueAccum (\a b -> (b : a, b)) [] s + + -- Run the function, and turn its outputs into a list of Values + o <- apply noLoc mempty f v' + let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o + + -- For each output.. + let m = + fromMaybe (error "vjp: differentiation failed") $ + zipWithM + ( \on sn -> case on of + -- If it is a VJP variable of the correct depth, run deriveTape on it and its corresponding seed + (ValueAD d (AD.VJP (AD.VJPValue t))) | d == depth -> (putAD $ AD.tapePrimal t,) <$> AD.deriveTape t sn + -- Otherwise, its partial derivatives are all 0 + _ -> Just (on, M.empty) + ) + o' + s' + + -- Add together every derivative + let drvs = M.map (Just . putAD) $ M.unionsWith add $ map snd m + + -- Extract the output values (m is in reverse traversal order, like o'), and the partial derivatives + let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o + let od = + fromMaybe (error "vjp: differentiation failed") $ + modifyValueM (\i vo -> M.findWithDefault (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo) i drvs) v + + -- Return a tuple of the output values, and partial derivatives + pure $ toTuple [ov, od] + where + modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v + modifyValueM f v = + snd + <$> valueAccumLM + ( \a b -> do + b' <- f a b + pure (a + 1, b') + ) + 0 + v + + -- TODO: Perhaps this could be fully abstracted by AD? + -- Making addFor private would be nice.. + add x y = + fromMaybe (error "vjp: illtyped add") $ + AD.doOp (AD.OpBin $ AD.addFor $ P.primValueType $ AD.primitive x) [x, y] def "jvp2" = Just $ - fun3 $ - \_ _ _ -> bad noLoc mempty "Interpreter does not support autodiff." + -- TODO: This could be much better. Currently, it is very inefficient + -- Perhaps creating JVPValues could be abstracted into a function + -- exposed by the AD module? 
+ fun3 $ \f v s -> do + -- Get the depth + depth <- length <$> stacktrace + + -- Turn the seeds into a list of ADValues + let s' = + expectJust ("jvp: invalid seeds " ++ show s) $ + mapM getAD $ + fst $ + valueAccum (\a b -> (b : a, b)) [] s + -- Augment the values + let v' = + expectJust ("jvp: invalid values " ++ show v) $ + modifyValueM + ( \i lv -> do + lv' <- getAD lv + pure $ ValueAD depth . AD.JVP . AD.JVPValue lv' $ s' !! (length s' - 1 - i) + ) + v + + -- Run the function, and turn its outputs into a list of Values + o <- apply noLoc mempty f v' + let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o + + -- For each output.. + let m = + expectJust "jvp: differentiation failed" $ + mapM + ( \on -> case on of + -- If it is a JVP variable of the correct depth, return its primal and derivative + (ValueAD d (AD.JVP (AD.JVPValue pv dv))) | d == depth -> Just (putAD pv, putAD dv) + -- Otherwise, its partial derivatives are all 0 + _ -> (on,) . ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD on + ) + o' + + -- Extract the output values, and the partial derivatives + let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o + od = modifyValue (\i _ -> snd $ m !! 
(length m - 1 - i)) o + + -- Return a tuple of the output values, and partial derivatives + pure $ toTuple [ov, od] + where + modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v + modifyValueM f v = + snd + <$> valueAccumLM + ( \a b -> do + b' <- f a b + pure (a + 1, b') + ) + 0 + v + + expectJust _ (Just v) = v + expectJust s Nothing = error s def "acc" = Nothing def s | nameFromString s `M.member` namesToPrimTypes = Nothing def s = error $ "Missing intrinsic: " ++ s @@ -1944,6 +2101,18 @@ initialCtx = in apply2 noLoc mempty f n arg stream _ arg = error $ "Cannot stream: " <> show arg +intrinsicVal :: Name -> Value +intrinsicVal name = + case M.lookup (intrinsicVar name) $ envTerm $ ctxEnv initialCtx of + Just (TermValue _ v) -> v + _ -> error $ "intrinsicVal: " <> prettyString name + +intrinsicsMinus :: Value +intrinsicsMinus = intrinsicVal "-" + +intrinsicsNot :: Value +intrinsicsNot = intrinsicVal "!" + interpretExp :: Ctx -> Exp -> F ExtOp Value interpretExp ctx e = runEvalM (ctxImports ctx) $ eval (ctxEnv ctx) e diff --git a/src/Language/Futhark/Interpreter/AD.hs b/src/Language/Futhark/Interpreter/AD.hs new file mode 100644 index 0000000000..525013ec62 --- /dev/null +++ b/src/Language/Futhark/Interpreter/AD.hs @@ -0,0 +1,320 @@ +module Language.Futhark.Interpreter.AD + ( Op (..), + ADVariable (..), + ADValue (..), + Tape (..), + VJPValue (..), + JVPValue (..), + doOp, + addFor, + primal, + tapePrimal, + primitive, + deriveTape, + ) +where + +import Control.Monad (foldM, zipWithM) +import Data.Either (isRight) +import Data.List (find) +import Data.Map qualified as M +import Data.Maybe (fromMaybe) +import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp) +import Futhark.Analysis.PrimExp (PrimExp (..)) +import Language.Futhark.Core (VName (..), nameFromString) +import Language.Futhark.Primitive + +-- Mathematical operations subject to AD. 
+data Op + = OpBin BinOp + | OpCmp CmpOp + | OpUn UnOp + | OpFn String + | OpConv ConvOp + deriving (Show) + +-- Checks if an operation matches the types of its operands +opTypeMatch :: Op -> [PrimType] -> Bool +opTypeMatch (OpBin op) p = all (\x -> binOpType op == x) p +opTypeMatch (OpCmp op) p = all (\x -> cmpOpType op == x) p +opTypeMatch (OpUn op) p = all (\x -> unOpType op == x) p +opTypeMatch (OpConv op) p = all (\x -> fst (convOpType op) == x) p +opTypeMatch (OpFn fn) p = case M.lookup fn primFuns of + Just (t, _, _) -> and $ zipWith (==) t p + Nothing -> error "opTypeMatch" -- It is assumed that the function exists + +-- Gets the return type of an operation +opReturnType :: Op -> PrimType +opReturnType (OpBin op) = binOpType op +opReturnType (OpCmp op) = cmpOpType op +opReturnType (OpUn op) = unOpType op +opReturnType (OpConv op) = snd $ convOpType op +opReturnType (OpFn fn) = case M.lookup fn primFuns of + Just (_, t, _) -> t + Nothing -> error "opReturnType" -- It is assumed that the function exists + +-- Returns the operation which performs addition (or an +-- equivalent operation) on the given type +addFor :: PrimType -> BinOp +addFor (IntType t) = Add t OverflowWrap +addFor (FloatType t) = FAdd t +addFor Bool = LogOr +addFor t = error $ "addFor: " ++ show t + +-- Returns the function which performs multiplication +-- (or an equivalent operation) on the given type +mulFor :: PrimType -> BinOp +mulFor (IntType t) = Mul t OverflowWrap +mulFor (FloatType t) = FMul t +mulFor Bool = LogAnd +mulFor t = error $ "mulFor: " ++ show t + +-- Types and utility functions-- +-- When taking the partial derivative of a function, we +-- must differentiate between the values which are kept +-- constant, and those which are not +data ADValue + = Variable Int ADVariable + | Constant PrimValue + deriving (Show) + +-- When performing automatic differentiation, each derived +-- variable must be augmented with additional data. 
This +-- value holds the primitive value of the variable, as well +-- as its data +data ADVariable + = VJP VJPValue + | JVP JVPValue + deriving (Show) + +depth :: ADValue -> Int +depth (Variable d _) = d +depth (Constant _) = 0 + +primal :: ADValue -> ADValue +primal (Variable _ (VJP (VJPValue t))) = tapePrimal t +primal (Variable _ (JVP (JVPValue v _))) = primal v -- NOTE(review): recurses through nested JVP layers, while the VJP case unwraps a single tape; confirm this is intended for nested AD +primal (Constant v) = Constant v + +primitive :: ADValue -> PrimValue +primitive v@(Variable _ _) = primitive $ primal v +primitive (Constant v) = v + +-- Evaluates a PrimExp using doOp +evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> Maybe ADValue +evalPrimExp m (LeafExp n _) = M.lookup n m +evalPrimExp _ (ValueExp pv) = Just $ Constant pv +evalPrimExp m (BinOpExp op x y) = do + x' <- evalPrimExp m x + y' <- evalPrimExp m y + doOp (OpBin op) [x', y'] +evalPrimExp m (CmpOpExp op x y) = do + x' <- evalPrimExp m x + y' <- evalPrimExp m y + doOp (OpCmp op) [x', y'] +evalPrimExp m (UnOpExp op x) = do + x' <- evalPrimExp m x + doOp (OpUn op) [x'] +evalPrimExp m (ConvOpExp op x) = do + x' <- evalPrimExp m x + doOp (OpConv op) [x'] +evalPrimExp m (FunExp fn p _) = do + p' <- mapM (evalPrimExp m) p + doOp (OpFn fn) p' + +-- Returns a list of PrimExps calculating the partial +-- derivative of each operand of a given operation +lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName] +lookupPDs (OpBin op) [x, y] = Just $ do + let (a, b) = pdBinOp op x y + [a, b] +lookupPDs (OpUn op) [x] = Just [pdUnOp op x] +lookupPDs (OpFn fn) p = pdBuiltin (nameFromString fn) p +lookupPDs _ _ = Nothing + +-- Shared AD logic-- +-- This function performs a mathematical operation on a +-- list of operands, performing automatic differentiation +-- if one or more operands is a Variable (of depth > 0) +doOp :: Op -> [ADValue] -> Maybe ADValue +doOp op o + | not $ opTypeMatch op (map primValueType pv) = + -- This function may be called with arguments of invalid types, + -- because it is used as part of an overloaded 
operator. + Nothing + | otherwise = do + let dep = case op of + OpCmp _ -> 0 -- AD is not well-defined for comparison operations + -- There are no derivatives for those written in + -- PrimExp (check lookupPDs) + _ -> maximum (map depth o) + if dep == 0 then constCase else nonconstCase dep + where + pv = map primitive o + + divideDepths :: Int -> ADValue -> Either ADValue ADVariable + divideDepths _ v@(Constant {}) = Left v + divideDepths d v@(Variable d' v') = if d' < d then Left v else Right v' + + -- TODO: There may be a more graceful way of + -- doing this + extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue + extractVJP (Right (VJP v)) = Right v + extractVJP (Left v) = Left v + extractVJP _ = + -- This will never be called when the maximum depth layer is JVP + error "extractVJP" + + -- TODO: There may be a more graceful way of + -- doing this + extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue + extractJVP (Right (JVP v)) = Right v + extractJVP (Left v) = Left v + extractJVP _ = + -- This will never be called when the maximum depth layer is VJP + error "extractJVP" + + -- In this case, every operand is a constant, and the + -- mathematical operation can be applied as it would be + -- otherwise + constCase = + Constant <$> case (op, pv) of + (OpBin op', [x, y]) -> doBinOp op' x y + (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y + (OpUn op', [x]) -> doUnOp op' x + (OpConv op', [x]) -> doConvOp op' x + (OpFn fn, _) -> do + (_, _, f) <- M.lookup fn primFuns + f pv + _ -> error "doOp: opTypeMatch" + + nonconstCase dep = do + -- In this case, some values are variables. 
We therefore + -- have to perform the necessary steps for AD + + -- First, we calculate the value for the previous depth + let oprev = map primal o + vprev <- doOp op oprev + + -- Then we separate the values of the maximum depth from + -- those of a lower depth + let o' = map (divideDepths dep) o + -- Then we find out what type of AD is being performed + case find isRight o' of + -- Finally, we perform the necessary steps for the given + -- type of AD + Just (Right (VJP {})) -> + Just . Variable dep . VJP . VJPValue $ vjpHandleOp op (map extractVJP o') vprev + Just (Right (JVP {})) -> + Variable dep . JVP . JVPValue vprev <$> jvpHandleFn op (map extractJVP o') + _ -> + -- Since the maximum depth is non-zero, there must be at + -- least one variable of depth > 0 + error "find isRight" + +calculatePDs :: Op -> [ADValue] -> Maybe [ADValue] +calculatePDs op p = do + -- Create a unique VName for each operand + let n = map (\i -> VName (nameFromString $ "x" ++ show i) i) [1 .. length p] + -- Put the operands in the environment + let m = M.fromList $ zip n p + + -- Look up, and calculate the partial derivative + -- of the operation with respect to each operand + pde <- lookupPDs op $ map (`LeafExp` opReturnType op) n + mapM (evalPrimExp m) pde + +-- VJP / Reverse mode automatic differentiation-- +-- In reverse mode AD, the entire computation +-- leading up to a variable must be saved +-- This is represented as a Tape +newtype VJPValue = VJPValue Tape + deriving (Show) + +-- | Represents a computation tree, as well as every intermediate +-- value in its evaluation. TODO: make this a graph. +data Tape + = -- | This represents a variable. Each variable is given a unique ID, + -- and has an initial value + TapeID Int ADValue + | -- | This represents a constant. + TapeConst ADValue + | -- | This represents the application of a mathematical operation. 
+ -- Each parameter is given by its Tape, and the return value of + -- the operation is saved + TapeOp Op [Tape] ADValue + deriving (Show) + +-- | Returns the primal value of a Tape. +tapePrimal :: Tape -> ADValue +tapePrimal (TapeID _ v) = v +tapePrimal (TapeConst v) = v +tapePrimal (TapeOp _ _ v) = v + +-- This updates Tape of a VJPValue with a new operation, +-- treating all operands of a lower depth as constants +vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> Tape +vjpHandleOp op p v = do + TapeOp op (map toTape p) v + where + toTape (Left v') = TapeConst v' + toTape (Right (VJPValue t)) = t + +-- | This calculates every partial derivative of a 'Tape'. The result +-- is a map of the partial derivatives, each key corresponding to the +-- ID of a free variable (see TapeID). +deriveTape :: Tape -> ADValue -> Maybe (M.Map Int ADValue) +deriveTape (TapeID i _) s = Just $ M.fromList [(i, s)] +deriveTape (TapeConst _) _ = Just M.empty +deriveTape (TapeOp op p _) s = do + -- Calculate the new sensitivities + s'' <- case op of + OpConv op' -> do + -- In case of type conversion, simply convert the sensitivity + s' <- doOp (OpConv $ flipConvOp op') [s] + Just [s'] + _ -> do + pds <- calculatePDs op $ map tapePrimal p + mapM (mul s) pds + + -- Propagate the new sensitivities + pd <- zipWithM deriveTape p s'' + -- Add up the results + Just $ foldl (M.unionWith add) M.empty pd + where + add x y = + fromMaybe (error "deriveTape: addition failed") $ + doOp (OpBin $ addFor $ opReturnType op) [x, y] + mul x y = doOp (OpBin $ mulFor $ opReturnType op) [x, y] + +-- JVP / Forward mode automatic differentiation-- + +-- | In JVP, the derivative of the variable must be saved. This is +-- represented as a second value. +data JVPValue = JVPValue ADValue ADValue + deriving (Show) + +-- | This calculates the derivative part of the JVPValue resulting +-- from the application of a mathematical operation on one or more +-- JVPValues. 
+jvpHandleFn :: Op -> [Either ADValue JVPValue] -> Maybe ADValue +jvpHandleFn op p = do + case op of + OpConv _ -> + -- In case of type conversion, simply convert + -- the old derivative + doOp op [derivative $ head p] + _ -> do + -- Calculate the new derivative using the chain + -- rule + pds <- calculatePDs op $ map primal' p + vs <- zipWithM mul pds $ map derivative p + foldM add (Constant $ blankPrimValue $ opReturnType op) vs + where + primal' (Left v) = v + primal' (Right (JVPValue v _)) = v + derivative (Left v) = Constant $ blankPrimValue $ primValueType $ primitive v + derivative (Right (JVPValue _ d)) = d + + add x y = doOp (OpBin $ addFor $ opReturnType op) [x, y] + mul x y = doOp (OpBin $ mulFor $ opReturnType op) [x, y] diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index b03f57cbbb..6b1863ce18 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -14,6 +14,8 @@ module Language.Futhark.Interpreter.Values valueShape, prettyValue, valueText, + valueAccum, + valueAccumLM, fromTuple, arrayLength, isEmptyArray, @@ -28,6 +30,7 @@ module Language.Futhark.Interpreter.Values where import Data.Array +import Data.Bifunctor (Bifunctor (second)) import Data.List (genericLength) import Data.Map qualified as M import Data.Maybe @@ -35,9 +38,10 @@ import Data.Monoid hiding (Sum) import Data.Text qualified as T import Data.Vector.Storable qualified as SVec import Futhark.Data qualified as V -import Futhark.Util (chunk) +import Futhark.Util (chunk, mapAccumLM) import Futhark.Util.Pretty import Language.Futhark hiding (Shape, matchDims) +import Language.Futhark.Interpreter.AD qualified as AD import Language.Futhark.Primitive qualified as P import Prelude hiding (break, mod) @@ -106,6 +110,8 @@ data Value m ValueSum ValueShape Name [Value m] | -- The shape, the update function, and the array. 
ValueAcc ValueShape (Value m -> Value m -> m (Value m)) !(Array Int (Value m)) + | -- A primitive value with added information used in automatic differentiation + ValueAD Int AD.ADVariable instance Show (Value m) where show (ValuePrim v) = "ValuePrim " <> show v <> "" @@ -114,6 +120,7 @@ instance Show (Value m) where show (ValueSum shape c vs) = unwords ["ValueSum", "(" <> show shape <> ")", show c, "(" <> show vs <> ")"] show ValueFun {} = "ValueFun _" show ValueAcc {} = "ValueAcc _" + show (ValueAD d v) = unwords ["ValueAD", show d, show v] instance Eq (Value m) where ValuePrim (SignedValue x) == ValuePrim (SignedValue y) = @@ -145,6 +152,8 @@ prettyValueWith pprPrim = pprPrec 0 pprPrec _ ValueAcc {} = "#" pprPrec p (ValueSum _ n vs) = parensIf (p > (0 :: Int)) $ "#" <> sep (pretty n : map (pprPrec 1) vs) + -- TODO: This could be prettier. Perhaps add pretty printing for ADVariable / ADValues + pprPrec _ (ValueAD d v) = pretty $ "d[" ++ show d ++ "]" ++ show v pprElem v@ValueArray {} = pprPrec 0 v pprElem v = group $ pprPrec 0 v @@ -182,6 +191,40 @@ valueShape (ValueRecord fs) = ShapeRecord $ M.map valueShape fs valueShape (ValueSum shape _ _) = shape valueShape _ = ShapeLeaf +-- TODO: Perhaps there is some clever way to reuse the code between +-- valueAccum and valueAccumLM +valueAccum :: (a -> Value m -> (a, Value m)) -> a -> Value m -> (a, Value m) +valueAccum f i v@(ValuePrim {}) = f i v +valueAccum f i v@(ValueAD {}) = f i v +valueAccum f i (ValueRecord m) = second ValueRecord $ M.mapAccum (valueAccum f) i m +valueAccum f i (ValueArray s a) = do + -- TODO: This could probably be better + -- Transform into a map + let m = M.fromList $ assocs a + -- Accumulate over the map + let (i', m') = M.mapAccum (valueAccum f) i m + -- Transform back into an array and return + let a' = array (bounds a) (M.toList m') + (i', ValueArray s a') +valueAccum _ _ v = error $ "valueAccum not implemented for " ++ show v + +valueAccumLM :: (Monad f) => (a -> Value m -> f (a, Value 
m)) -> a -> Value m -> f (a, Value m) +valueAccumLM f i v@(ValuePrim {}) = f i v +valueAccumLM f i v@(ValueAD {}) = f i v +valueAccumLM f i (ValueRecord m) = do + (a, b) <- mapAccumLM (valueAccumLM f) i m + pure (a, ValueRecord b) +valueAccumLM f i (ValueArray s a) = do + -- TODO: This could probably be better + -- Transform into a map + let m = M.fromList $ assocs a + -- Accumulate over the map + (i', m') <- mapAccumLM (valueAccumLM f) i m + -- Transform back into an array and return + let a' = array (bounds a) (M.toList m') + pure (i', ValueArray s a') +valueAccumLM _ _ v = error $ "valueAccumLM not implemented for " ++ show v + -- | Does the value correspond to an empty array? isEmptyArray :: Value m -> Bool isEmptyArray = emptyShape . valueShape diff --git a/tests/ad/arr0.fut b/tests/ad/arr0.fut index 6b67c7b6e4..f20cfe1b54 100644 --- a/tests/ad/arr0.fut +++ b/tests/ad/arr0.fut @@ -2,7 +2,7 @@ def f (xs: [2]f64) = xs[0] * xs[1] -- == -- entry: f_jvp --- compiled input { [5.0, 7.0] } +-- input { [5.0, 7.0] } -- output { 7.0 5.0 } entry f_jvp xs = @@ -11,7 +11,7 @@ entry f_jvp xs = -- == -- entry: f_vjp --- compiled input { [5.0, 7.0] } +-- input { [5.0, 7.0] } -- output { [7.0, 5.0] } entry f_vjp xs = diff --git a/tests/ad/arr1.fut b/tests/ad/arr1.fut index df135a67a5..bc82060033 100644 --- a/tests/ad/arr1.fut +++ b/tests/ad/arr1.fut @@ -2,7 +2,7 @@ def f (x, y) : [2]f64 = [x+y, x*y] -- == -- entry: f_vjp f_jvp --- compiled input { 5.0 7.0 } +-- input { 5.0 7.0 } -- output { [1.0,7.0] [1.0, 5.0] } entry f_jvp x y = diff --git a/tests/ad/arr2.fut b/tests/ad/arr2.fut index 23248bf398..e382b3cf02 100644 --- a/tests/ad/arr2.fut +++ b/tests/ad/arr2.fut @@ -2,7 +2,7 @@ def f (x, y) : [2][1]f64 = [x, y] -- == -- entry: f_vjp f_jvp --- compiled input { [5.0] [7.0] } +-- input { [5.0] [7.0] } -- output { [[1.0],[0.0]] [[0.0], [1.0]] } entry f_jvp x y = diff --git a/tests/ad/concat0.fut b/tests/ad/concat0.fut index c3e7f3e0bf..d80440ad2a 100644 --- a/tests/ad/concat0.fut
+++ b/tests/ad/concat0.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp --- compiled input { [1,2,3] [4,5,6] } +-- input { [1,2,3] [4,5,6] } -- output { [1,2,3,4,5,6] } entry f_jvp xs ys : []i32 = @@ -8,7 +8,7 @@ entry f_jvp xs ys : []i32 = -- == -- entry: f_vjp --- compiled input { [1,2,3] [4,5,6] } +-- input { [1,2,3] [4,5,6] } -- output { [1,2,3] [4,5,6] } entry f_vjp xs ys : ([]i32, []i32) = diff --git a/tests/ad/confusion0.fut b/tests/ad/confusion0.fut index 4e861ceb43..d745b47ca2 100644 --- a/tests/ad/confusion0.fut +++ b/tests/ad/confusion0.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd rev --- compiled input { 1 2 } output { 1 } +-- input { 1 2 } output { 1 } def d f x = jvp f x 1 diff --git a/tests/ad/consume0.fut b/tests/ad/consume0.fut index 04415c13d2..0a61021991 100644 --- a/tests/ad/consume0.fut +++ b/tests/ad/consume0.fut @@ -1,6 +1,6 @@ -- == -- entry: rev fwd --- compiled input { [1.0,2.0,3.0] } +-- input { [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } def f (xs: []f64) = diff --git a/tests/ad/consume1.fut b/tests/ad/consume1.fut index 9df6e566f2..c54c45c1c7 100644 --- a/tests/ad/consume1.fut +++ b/tests/ad/consume1.fut @@ -1,6 +1,6 @@ -- == -- entry: rev fwd --- compiled input { true [1.0,2.0,3.0] } +-- input { true [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } def f b (xs: []f64) = diff --git a/tests/ad/consume2.fut b/tests/ad/consume2.fut index 199ae2d3e7..477f2fbe9b 100644 --- a/tests/ad/consume2.fut +++ b/tests/ad/consume2.fut @@ -1,6 +1,6 @@ -- == -- entry: rev fwd --- compiled input { [true] [[1.0,2.0,3.0]] [[0.0,1.0,0.0]] } +-- input { [true] [[1.0,2.0,3.0]] [[0.0,1.0,0.0]] } -- output { [[0.000000f64, 1.000000f64, 0.000000f64]] } def f b (xs: []f64) = diff --git a/tests/ad/consume3.fut b/tests/ad/consume3.fut index 27063d96f8..efbe4fe209 100644 --- a/tests/ad/consume3.fut +++ b/tests/ad/consume3.fut @@ -5,12 +5,12 @@ def test [n] (xs: [n]f64) = -- == -- entry: prim --- compiled 
input { [5.0, 7.0, 9.0] } +-- input { [5.0, 7.0, 9.0] } -- output { [5.0, 7.0, 9.0] } entry prim [n] (xs: [n]f64) = test xs -- == -- entry: f_vjp --- compiled input { [5.0, 7.0, 9.0] } +-- input { [5.0, 7.0, 9.0] } -- output { [1.0, 1.0, 1.0] } entry f_vjp [n] (xs: [n]f64) = vjp test xs (replicate n 1) diff --git a/tests/ad/consume4.fut b/tests/ad/consume4.fut index 0efe910a6c..3cbe37c50e 100644 --- a/tests/ad/consume4.fut +++ b/tests/ad/consume4.fut @@ -5,18 +5,18 @@ def test [n] (xs: [n]i32) = -- == -- entry: prim --- compiled input { [5, 7, 9] } +-- input { [5, 7, 9] } -- output { [5, 14, 9] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_vjp --- compiled input { [5, 7, 9] } +-- input { [5, 7, 9] } -- output { [1, 2, 1] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) -- == -- entry: f_jvp --- compiled input { [5, 7, 9] } +-- input { [5, 7, 9] } -- output { [1, 2, 1] } entry f_jvp [n] (xs: [n]i32) = jvp test xs (replicate n 1) diff --git a/tests/ad/consume5.fut b/tests/ad/consume5.fut index ed03ceb3aa..08ac477bb4 100644 --- a/tests/ad/consume5.fut +++ b/tests/ad/consume5.fut @@ -5,6 +5,6 @@ def test [n] (xs: [n]i32) = -- == -- entry: f_vjp --- compiled input { [1, 2, 3] } +-- input { [1, 2, 3] } -- output { [4, 2, 4] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) diff --git a/tests/ad/consume6.fut b/tests/ad/consume6.fut index 94bb2d6d0e..c893f6ba1d 100644 --- a/tests/ad/consume6.fut +++ b/tests/ad/consume6.fut @@ -7,10 +7,10 @@ def test [n] (xs: [n]i32) = -- == -- entry: prim --- compiled input { [1,2,3,4,5] } output { [1,1,1,1,1] } +-- input { [1,2,3,4,5] } output { [1,1,1,1,1] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_vjp --- compiled input { [1,2,3,4,5] } output { [0,0,0,0,0] } +-- input { [1,2,3,4,5] } output { [0,0,0,0,0] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) diff --git a/tests/ad/conv0.fut b/tests/ad/conv0.fut index ab78a70ad4..a242ab6274 100644 --- a/tests/ad/conv0.fut +++ 
b/tests/ad/conv0.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd --- compiled input { 1.0 } +-- input { 1.0 } -- output { 1f32 } entry fwd x = @@ -8,7 +8,7 @@ entry fwd x = -- == -- entry: rev --- compiled input { 1.0 } +-- input { 1.0 } -- output { 1f64 } entry rev x = diff --git a/tests/ad/conv1.fut b/tests/ad/conv1.fut index 8cfbf04380..c812ef7c31 100644 --- a/tests/ad/conv1.fut +++ b/tests/ad/conv1.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd --- compiled input { 1f64 } +-- input { 1f64 } -- output { 1i32 } entry fwd x = @@ -8,7 +8,7 @@ entry fwd x = -- == -- entry: rev --- compiled input { 1f64 } +-- input { 1f64 } -- output { 2f64 } entry rev x = diff --git a/tests/ad/fadd.fut b/tests/ad/fadd.fut index 369a34df56..903c283135 100644 --- a/tests/ad/fadd.fut +++ b/tests/ad/fadd.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp f_vjp --- compiled input { 5.0 7.0 } +-- input { 5.0 7.0 } -- output { 1.0 1.0 } def f (x,y) = x + y : f64 diff --git a/tests/ad/fdiv.fut b/tests/ad/fdiv.fut index 768263e783..001ae59798 100644 --- a/tests/ad/fdiv.fut +++ b/tests/ad/fdiv.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp f_vjp --- compiled input { 5.0 7.0 } +-- input { 5.0 7.0 } -- output { 0.14285 -0.102041 } def f (x,y) = x / y : f64 diff --git a/tests/ad/fmul.fut b/tests/ad/fmul.fut index 443079e818..a2643d40b9 100644 --- a/tests/ad/fmul.fut +++ b/tests/ad/fmul.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp f_vjp --- compiled input { 5.0 7.0 } +-- input { 5.0 7.0 } -- output { 7.0 5.0 } def f (x,y) = x * y : f64 diff --git a/tests/ad/for0.fut b/tests/ad/for0.fut index 984ad7678a..9551b8dd69 100644 --- a/tests/ad/for0.fut +++ b/tests/ad/for0.fut @@ -2,13 +2,13 @@ def pow y x = loop acc = 1 for _i < y do acc * x -- == -- entry: prim --- compiled input { 3 4 } output { 64 } --- compiled input { 9 3 } output { 19683 } +-- input { 3 4 } output { 64 } +-- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp f_vjp --- compiled input { 3 4 } output { 48 } --- compiled input { 9 3 } 
output { 59049 } +-- input { 3 4 } output { 48 } +-- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1 diff --git a/tests/ad/for1.fut b/tests/ad/for1.fut index 87fcf475d1..b131d2e2ec 100644 --- a/tests/ad/for1.fut +++ b/tests/ad/for1.fut @@ -3,12 +3,12 @@ def pow_list [n] y (xs :[n]i32) = loop accs = (replicate n 1) for _i < y do -- == -- entry: prim --- compiled input { 3 [1,2,3] } output { [1,8,27] } +-- input { 3 [1,2,3] } output { [1,8,27] } entry prim y xs = pow_list y xs -- == -- entry: f_vjp f_jvp --- compiled input { 3 [1,2,3] } +-- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], -- [0,0,27]] diff --git a/tests/ad/for2.fut b/tests/ad/for2.fut index 3051951673..235e9b49cb 100644 --- a/tests/ad/for2.fut +++ b/tests/ad/for2.fut @@ -4,12 +4,12 @@ def mult_list xs = -- == -- entry: prim --- compiled input { [11,5,13] } output { 169 } +-- input { [11,5,13] } output { 169 } entry prim = mult_list -- == -- entry: f_jvp f_vjp --- compiled input { [11,5,13] } output { [0,0,26] } +-- input { [11,5,13] } output { [0,0,26] } entry f_jvp [n] (xs :[n]i32) = tabulate n (\i -> jvp mult_list xs (replicate n 0 with [i] = 1)) entry f_vjp [n] (xs: [n]i32) = vjp mult_list xs 1 diff --git a/tests/ad/for3.fut b/tests/ad/for3.fut index a2d1e2b543..4c697e7d57 100644 --- a/tests/ad/for3.fut +++ b/tests/ad/for3.fut @@ -6,12 +6,12 @@ def square [n] (xs: [n]i32) = -- == -- entry: prim --- compiled input { [1,2,3,4,5] } output { [1,4,9,16,25] } +-- input { [1,2,3,4,5] } output { [1,4,9,16,25] } entry prim [n] (xs: [n]i32) = square xs -- == -- entry: f_jvp f_vjp --- compiled input { [1,2,3,4,5] } +-- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], -- [0,0,6,0,0], diff --git a/tests/ad/fwd/acc0.fut b/tests/ad/fwd/acc0.fut index ca07650ea5..eef91fe350 100644 --- a/tests/ad/fwd/acc0.fut +++ b/tests/ad/fwd/acc0.fut @@ -4,7 +4,7 @@ def f (acc : *acc([]i32)) i = write acc i (i32.i64 i) -- square entries -- == -- entry: prim 
--- compiled input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } +-- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } -- output { [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] } entry prim [n] (xs: [n]i32) = @@ -13,7 +13,7 @@ entry prim [n] (xs: [n]i32) = -- == -- entry: f_jvp --- compiled input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } +-- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } -- output { [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] } entry f_jvp (xs: *[]i32) = jvp prim xs (replicate 10 1) diff --git a/tests/ad/fwd/for0.fut b/tests/ad/fwd/for0.fut index 2115e768e4..0afa061764 100644 --- a/tests/ad/fwd/for0.fut +++ b/tests/ad/fwd/for0.fut @@ -2,14 +2,14 @@ def pow y x = loop acc = 1 for i < y do acc * x -- == -- entry: prim --- compiled input { 3 4 } output { 64 } --- compiled input { 9 3 } output { 19683 } +-- input { 3 4 } output { 64 } +-- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp --- compiled input { 3 4 } output { 48 } --- compiled input { 9 3 } output { 59049 } +-- input { 3 4 } output { 48 } +-- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 diff --git a/tests/ad/fwd/for1.fut b/tests/ad/fwd/for1.fut index bb8dcf76f2..cff3fdcbf6 100644 --- a/tests/ad/fwd/for1.fut +++ b/tests/ad/fwd/for1.fut @@ -3,14 +3,14 @@ def pow y x = loop acc = 1 for i in [y, y*y] do acc * x * i -- == -- entry: prim --- compiled input { 3 4 } output { 432 } --- compiled input { 9 3 } output { 6561 } +-- input { 3 4 } output { 432 } +-- input { 9 3 } output { 6561 } entry prim y x = pow y x -- == -- entry: f_jvp --- compiled input { 3 4 } output { 216 } --- compiled input { 9 3 } output { 4374 } +-- input { 3 4 } output { 216 } +-- input { 9 3 } output { 4374 } entry f_jvp y x = jvp (pow y) x 1 diff --git a/tests/ad/fwd/map0.fut b/tests/ad/fwd/map0.fut index ec782d697a..6a34263f71 100644 --- a/tests/ad/fwd/map0.fut +++ b/tests/ad/fwd/map0.fut @@ -2,6 +2,6 @@ def f x = map (*(x*x)) [0,1,2] -- == -- entry: f_jvp --- compiled input { 2 } output { [0, 4, 8] } --- compiled input { 
4 } output { [0, 8, 16] } +-- input { 2 } output { [0, 4, 8] } +-- input { 4 } output { [0, 8, 16] } entry f_jvp x = jvp f x 1 diff --git a/tests/ad/fwd/red0.fut b/tests/ad/fwd/red0.fut index c1dd097733..7ad6caa19e 100644 --- a/tests/ad/fwd/red0.fut +++ b/tests/ad/fwd/red0.fut @@ -2,6 +2,6 @@ def f x = reduce (*) 1 [1,2,x,4] -- == -- entry: f_jvp --- compiled input { 3 } output { 8 } --- compiled input { 10 } output { 8 } +-- input { 3 } output { 8 } +-- input { 10 } output { 8 } entry f_jvp x = jvp f x 1 diff --git a/tests/ad/fwd/scatter0.fut b/tests/ad/fwd/scatter0.fut index 42a6455b4c..8f52d2fdd9 100644 --- a/tests/ad/fwd/scatter0.fut +++ b/tests/ad/fwd/scatter0.fut @@ -4,5 +4,5 @@ def f x = -- == -- entry: f_jvp --- compiled input { 5 } output { [1, 10, 75, 0, 0] } +-- input { 5 } output { [1, 10, 75, 0, 0] } entry f_jvp x = jvp f x 1 diff --git a/tests/ad/fwd/while0.fut b/tests/ad/fwd/while0.fut index 9eb6b1bb9c..13ce78c3bc 100644 --- a/tests/ad/fwd/while0.fut +++ b/tests/ad/fwd/while0.fut @@ -3,14 +3,14 @@ def pow y x = let (_, res) = loop (i, acc) = (0, 1) while i < y do in res -- == -- entry: prim --- compiled input { 3 4 } output { 64 } --- compiled input { 9 3 } output { 19683 } +-- input { 3 4 } output { 64 } +-- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp --- compiled input { 3 4 } output { 48 } --- compiled input { 9 3 } output { 59049 } +-- input { 3 4 } output { 48 } +-- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 diff --git a/tests/ad/gather0.fut b/tests/ad/gather0.fut index 546d46585e..0c1e485889 100644 --- a/tests/ad/gather0.fut +++ b/tests/ad/gather0.fut @@ -1,12 +1,12 @@ -- == -- entry: fwd_J rev_J --- compiled input { [4.0,3.0,2.0,1.0] [0i64,1i64,2i64,3i64] } +-- input { [4.0,3.0,2.0,1.0] [0i64,1i64,2i64,3i64] } -- output { [[1.0, 0.0, 0.0, 0.0], -- [0.0, 1.0, 0.0, 0.0], -- [0.0, 0.0, 1.0, 0.0], -- [0.0, 0.0, 0.0, 1.0]] -- } --- compiled input { [4.0,3.0,2.0,1.0] [0i64,0i64,3i64,3i64] } +-- 
input { [4.0,3.0,2.0,1.0] [0i64,0i64,3i64,3i64] } -- output { [[1.0, 0.0, 0.0, 0.0], -- [1.0, 0.0, 0.0, 0.0], -- [0.0, 0.0, 0.0, 1.0], diff --git a/tests/ad/gather1.fut b/tests/ad/gather1.fut index 182811306d..1e7d593074 100644 --- a/tests/ad/gather1.fut +++ b/tests/ad/gather1.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd_J rev_J --- compiled input +-- input -- { -- [[1.0,2.0],[3.0,4.0]] [1i64, 0i64, 1i64, 1i64] -- } diff --git a/tests/ad/gather2.fut b/tests/ad/gather2.fut index 7cdc67fed1..1854fc7eee 100644 --- a/tests/ad/gather2.fut +++ b/tests/ad/gather2.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd_J rev_J --- compiled input +-- input -- { -- [1.0,2.0,3.0,4.0] -- [[1i64, 3i64], [2i64, 2i64]] diff --git a/tests/ad/genred-opt/matmul.fut b/tests/ad/genred-opt/matmul.fut index 5a4ef7805e..bd2ec711c3 100644 --- a/tests/ad/genred-opt/matmul.fut +++ b/tests/ad/genred-opt/matmul.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd_J rev_J --- compiled input +-- input -- { -- [[1.0,2.0],[3.0,4.0]] [[5.0,6.0],[7.0,8.0]] -- } diff --git a/tests/ad/if0.fut b/tests/ad/if0.fut index b7bdc70e85..b8149724cb 100644 --- a/tests/ad/if0.fut +++ b/tests/ad/if0.fut @@ -1,8 +1,8 @@ -- == -- entry: f_jvp --- compiled input { true 5.0 7.0 } +-- input { true 5.0 7.0 } -- output { 7.0 5.0 } --- compiled input { false 5.0 7.0 } +-- input { false 5.0 7.0 } -- output { 0.14285 -0.102041 } def f (b, x, y) : f64 = @@ -14,9 +14,9 @@ entry f_jvp b x y = -- == -- entry: f_vjp --- compiled input { true 5.0 7.0 } +-- input { true 5.0 7.0 } -- output { false 7.0 5.0 } --- compiled input { false 5.0 7.0 } +-- input { false 5.0 7.0 } -- output { false 0.14285 -0.102041 } entry f_vjp b x y = diff --git a/tests/ad/if1.fut b/tests/ad/if1.fut index f3d5189d90..802a16bed2 100644 --- a/tests/ad/if1.fut +++ b/tests/ad/if1.fut @@ -1,8 +1,8 @@ -- == -- entry: f_jvp --- compiled input { false 5.0 } +-- input { false 5.0 } -- output { 2.0 } --- compiled input { true 5.0 } +-- input { true 5.0 } -- output { 11.0 } def f (b, x) : f64 = 
@@ -15,9 +15,9 @@ entry f_jvp b x = -- == -- entry: f_vjp --- compiled input { false 5.0 } +-- input { false 5.0 } -- output { false 2.0 } --- compiled input { true 5.0 } +-- input { true 5.0 } -- output { false 11.0 } entry f_vjp b x = diff --git a/tests/ad/if2.fut b/tests/ad/if2.fut index b6e45fd68b..d7d1ddd2b4 100644 --- a/tests/ad/if2.fut +++ b/tests/ad/if2.fut @@ -1,8 +1,8 @@ -- == -- entry: f_jvp --- compiled input { [1.0,2.0,3.0] } +-- input { [1.0,2.0,3.0] } -- output { [0.0, 3.0, 2.0] } --- compiled input { [-1.0,2.0,3.0] } +-- input { [-1.0,2.0,3.0] } -- output { [3.0, 0.0, -1.0] } -- structure { If/Replicate 0 } @@ -18,9 +18,9 @@ entry f_jvp x = -- == -- entry: f_vjp --- compiled input { [1.0,2.0,3.0] } +-- input { [1.0,2.0,3.0] } -- output { [0.0, 3.0, 2.0] } --- compiled input { [-1.0,2.0,3.0] } +-- input { [-1.0,2.0,3.0] } -- output { [3.0, 0.0, -1.0] } entry f_vjp x = diff --git a/tests/ad/imul.fut b/tests/ad/imul.fut index b78f43798a..97a531d57f 100644 --- a/tests/ad/imul.fut +++ b/tests/ad/imul.fut @@ -1,5 +1,5 @@ -- Check the absence of integer overflow. 
-- == --- compiled input { 2000000000i32 2000000000i32 } output { -294967296i32 } +-- input { 2000000000i32 2000000000i32 } output { -294967296i32 } def main x y : i32 = vjp (\x -> x * y) x 2 diff --git a/tests/ad/issue1577.fut b/tests/ad/issue1577.fut index e7fb3277e0..b684a0b7a2 100644 --- a/tests/ad/issue1577.fut +++ b/tests/ad/issue1577.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { [1i64,1i64,3i64,3i64] [1,2,3,4] } +-- input { [1i64,1i64,3i64,3i64] [1,2,3,4] } -- output { [0,3,0,7,0] } let red [n] (is: [n]i64) (vs: [n]i32) = diff --git a/tests/ad/issue1604.fut b/tests/ad/issue1604.fut index 8c1a9ae5f1..c4317acdc6 100644 --- a/tests/ad/issue1604.fut +++ b/tests/ad/issue1604.fut @@ -1,6 +1,6 @@ -- == -- entry: f_vjp --- compiled input { [1, 2, 3] } +-- input { [1, 2, 3] } -- output { [9, 9, 9] } def f [n] (xs: [n]i32) = diff --git a/tests/ad/issue1879.fut b/tests/ad/issue1879.fut index 4959d57fb1..453ee2a419 100644 --- a/tests/ad/issue1879.fut +++ b/tests/ad/issue1879.fut @@ -1,8 +1,8 @@ -- == -- entry: main_ad --- compiled input { [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] } +-- input { [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] } -- output { [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] } --- compiled input { [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]] } +-- input { [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]] } -- output { [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]] } def f [n] (xs: [n][3]f64) = diff --git a/tests/ad/lighthouse.fut b/tests/ad/lighthouse.fut index dac00cbee3..0080b42d7d 100644 --- a/tests/ad/lighthouse.fut +++ b/tests/ad/lighthouse.fut @@ -1,7 +1,7 @@ -- From [Griewank 2008]. 
-- == -- entry: lighthouse_jvp lighthouse_vjp --- compiled input { 2.0 1.5 0.4 2.1 } +-- input { 2.0 1.5 0.4 2.1 } -- output { 2.902513633461043f64 -15.102798701184362f64 95.71780341846966f64 18.23196255589898f64 -- 4.353770450191565f64 -16.849170784854458f64 143.57670512770449f64 27.347943833848472f64 -- } diff --git a/tests/ad/map0.fut b/tests/ad/map0.fut index ac1b10692d..7175ff875c 100644 --- a/tests/ad/map0.fut +++ b/tests/ad/map0.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [1,2,3] [3,2,1] } +-- input { [1,2,3] [3,2,1] } -- output { [6,4,2] } entry rev = vjp (map (*2i32)) diff --git a/tests/ad/map1.fut b/tests/ad/map1.fut index b6f9bb5786..d6011631eb 100644 --- a/tests/ad/map1.fut +++ b/tests/ad/map1.fut @@ -1,7 +1,7 @@ -- -- == -- entry: rev --- compiled input { [[1.0,2.0,3.0,4.0],[1.0,2.0,3.0,4.0]] [1.0,2.0] } +-- input { [[1.0,2.0,3.0,4.0],[1.0,2.0,3.0,4.0]] [1.0,2.0] } -- output {[[24.0, 12.0, 8.0, 6.0], -- [48.0, 24.0, 16.0, 12.0]] } diff --git a/tests/ad/map2.fut b/tests/ad/map2.fut index 65a133135f..75113fc866 100644 --- a/tests/ad/map2.fut +++ b/tests/ad/map2.fut @@ -1,7 +1,7 @@ -- Map with free variable. -- == -- entry: fwd_J rev_J --- compiled input { 2.0 [1.0,2.0,3.0] } +-- input { 2.0 [1.0,2.0,3.0] } -- output { [1.0,2.0,3.0] } def onehot n i : [n]f64 = diff --git a/tests/ad/map3.fut b/tests/ad/map3.fut index 9c4f0e0bfd..4ac46b350d 100644 --- a/tests/ad/map3.fut +++ b/tests/ad/map3.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd rev --- compiled input { 1i32 [1i32,2i32,3i32] } +-- input { 1i32 [1i32,2i32,3i32] } -- output { [1i32,2i32,3i32] } entry fwd [n] (x: i32) (xs: [n]i32) = diff --git a/tests/ad/map4.fut b/tests/ad/map4.fut index 2ebc0b2a5d..3c0b860039 100644 --- a/tests/ad/map4.fut +++ b/tests/ad/map4.fut @@ -1,7 +1,7 @@ -- An array is both a 'map' input and a free variable in the lambda. 
-- == -- entry: fwd_J rev_J --- compiled input { [1,2,3] } +-- input { [1,2,3] } -- output { -- [[[2, 0, 0], [1, 1, 0], [1, 0, 1]], [[1, 1, 0], [0, 2, 0], [0, 1, 1]], [[1, 0, 1], [0, 1, 1], [0, 0, 2]]] -- } diff --git a/tests/ad/map5.fut b/tests/ad/map5.fut index 06b360a8fc..39e411523d 100644 --- a/tests/ad/map5.fut +++ b/tests/ad/map5.fut @@ -1,7 +1,7 @@ -- Map with free array variable. -- == -- entry: fwd_J rev_J --- compiled input { [[1,2,3],[4,5,6]] [0,0] } +-- input { [[1,2,3],[4,5,6]] [0,0] } -- output { [[1, 0], [0, 1]] } def onehot n i : [n]i32 = diff --git a/tests/ad/map6.fut b/tests/ad/map6.fut index 5a34d1886d..b7a40aafa3 100644 --- a/tests/ad/map6.fut +++ b/tests/ad/map6.fut @@ -1,7 +1,7 @@ -- #1878 -- == -- entry: fwd_J rev_J --- compiled input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } +-- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } -- output { [[0.0, 2.0, 3.0, 4.0], -- [0.0, 0.0, 1.0, 1.0], -- [0.0, 0.0, 0.0, 1.0], diff --git a/tests/ad/map7.fut b/tests/ad/map7.fut index 7197e0c7d9..01f5c3e249 100644 --- a/tests/ad/map7.fut +++ b/tests/ad/map7.fut @@ -2,7 +2,7 @@ -- has active free variables. 
-- == -- entry: fwd_J rev_J --- compiled input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } +-- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } -- output { [0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0, 0.0] } def obj (x : [8]f64) = diff --git a/tests/ad/matmul.fut b/tests/ad/matmul.fut index 77b25ae7b7..fb5a9f6090 100644 --- a/tests/ad/matmul.fut +++ b/tests/ad/matmul.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd_J rev_J --- compiled input +-- input -- { -- [[1.0,2.0],[3.0,4.0]] [[5.0,6.0],[7.0,8.0]] -- } diff --git a/tests/ad/maximum.fut b/tests/ad/maximum.fut index 7645822893..f00a03889b 100644 --- a/tests/ad/maximum.fut +++ b/tests/ad/maximum.fut @@ -1,8 +1,8 @@ -- == -- entry: rev fwd --- compiled input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } +-- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] } --- compiled input { [1.0, 1.0] } +-- input { [1.0, 1.0] } -- output { [1.0, 0.0] } -- structure { /Screma 2 } diff --git a/tests/ad/minimum.fut b/tests/ad/minimum.fut index deb1d94bec..aa927d6d61 100644 --- a/tests/ad/minimum.fut +++ b/tests/ad/minimum.fut @@ -1,8 +1,8 @@ -- == -- entry: rev fwd --- compiled input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0] } +-- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] } --- compiled input { [1.0, 1.0] } +-- input { [1.0, 1.0] } -- output { [1.0, 0.0] } entry rev [n] (xs: [n]f64) = diff --git a/tests/ad/minmax.fut b/tests/ad/minmax.fut index edb9ff16fb..f2ac29d51b 100644 --- a/tests/ad/minmax.fut +++ b/tests/ad/minmax.fut @@ -1,6 +1,6 @@ -- == -- entry: rev fwd --- compiled input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } +-- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -- [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] -- } diff --git a/tests/ad/negate.fut b/tests/ad/negate.fut new file mode 100644 index 0000000000..ce81171e4f --- /dev/null +++ b/tests/ad/negate.fut @@ -0,0 +1,8 @@ +-- == +-- entry: fwd rev +-- input { 1f32 } 
output { -1f32 } + +def f x : f32 = -x + +entry fwd x = jvp f x 1 +entry rev x = vjp f x 1 diff --git a/tests/ad/nested0.fut b/tests/ad/nested0.fut index 51ff7a5296..9873cac89e 100644 --- a/tests/ad/nested0.fut +++ b/tests/ad/nested0.fut @@ -1,6 +1,6 @@ -- == -- entry: f_vjp --- compiled input { [1,2,3] } +-- input { [1,2,3] } -- output { [24,48,72] } def f [n] (xs: [n]i32) = map (\x -> x * x * x * x) xs diff --git a/tests/ad/nested1.fut b/tests/ad/nested1.fut index fbfa9402d1..51bf3052ff 100644 --- a/tests/ad/nested1.fut +++ b/tests/ad/nested1.fut @@ -1,6 +1,6 @@ -- == -- entry: f_vjp --- compiled input { [1,2,3] [0,1,2] } +-- input { [1,2,3] [0,1,2] } -- output { [6,12,18] [0,0,0] } def f [n] (xsis: ([n]i32, [n]i32)) = let (xs, is) = xsis diff --git a/tests/ad/nested2.fut b/tests/ad/nested2.fut index d70248fba8..795f2ece21 100644 --- a/tests/ad/nested2.fut +++ b/tests/ad/nested2.fut @@ -1,6 +1,6 @@ -- == -- entry: f_vjp --- compiled input { [1,2,3] [0,1,2] } +-- input { [1,2,3] [0,1,2] } -- output { [24,48,72] [0,0,0] } def f [n] (xsis: ([n]i32, [n]i32)) = let (xs, is) = xsis diff --git a/tests/ad/nested3.fut b/tests/ad/nested3.fut index 536efba99a..6a3af549bb 100644 --- a/tests/ad/nested3.fut +++ b/tests/ad/nested3.fut @@ -1,6 +1,6 @@ -- == -- entry: f_vjp --- compiled input { [[1,2,3],[1,2,3],[1,2,3]] [1,2,3]} +-- input { [[1,2,3],[1,2,3],[1,2,3]] [1,2,3]} -- output { [[1,1,1],[1,1,1],[1,1,1]] [3,3,3] } def f [n] (xssys : ([n][n]i32, [n]i32)) = let (xss,ys) = xssys diff --git a/tests/ad/nested4.fut b/tests/ad/nested4.fut index 9e4cf3c5e0..4aa334cc5d 100644 --- a/tests/ad/nested4.fut +++ b/tests/ad/nested4.fut @@ -1,6 +1,6 @@ -- == -- entry: f_vjp --- compiled input { [[1,2,3],[1,2,3],[1,2,3]] [0,1,2] [0,1,2]} +-- input { [[1,2,3],[1,2,3],[1,2,3]] [0,1,2] [0,1,2]} -- output { [[6,12,18],[6,12,18],[6,12,18]] [0,0,0] [0,0,0] } def f [n] (xssisjs: ([n][n]i32, [n]i32, [n]i32)) = let (xss, is, js) = xssisjs diff --git a/tests/ad/not.fut b/tests/ad/not.fut new file
mode 100644 index 0000000000..fba3f6a621 --- /dev/null +++ b/tests/ad/not.fut @@ -0,0 +1,8 @@ +-- == +-- entry: fwd rev +-- input { true } output { true } + +def f x : bool = !x + +entry fwd x = jvp f x true +entry rev x = vjp f x true diff --git a/tests/ad/rearrange0.fut b/tests/ad/rearrange0.fut index 000603e67f..dce287e292 100644 --- a/tests/ad/rearrange0.fut +++ b/tests/ad/rearrange0.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp --- compiled input { [[1,2],[3,4]] } +-- input { [[1,2],[3,4]] } -- output { [[1,3],[2,4]] } entry f_jvp (xss: [][]i32) = @@ -8,7 +8,7 @@ entry f_jvp (xss: [][]i32) = -- == -- entry: f_vjp --- compiled input { [[1,2],[3,4]] } +-- input { [[1,2],[3,4]] } -- output { [[1,3],[2,4]] } entry f_vjp (xss: [][]i32) = diff --git a/tests/ad/reduce0.fut b/tests/ad/reduce0.fut index c9b331737c..87b9aec1cd 100644 --- a/tests/ad/reduce0.fut +++ b/tests/ad/reduce0.fut @@ -1,7 +1,7 @@ -- Simple reduce with multiplication -- == -- entry: rev --- compiled input { [1.0f32, 2.0f32, 3.0f32, 4.0f32] 1.0f32 } output { [24.0f32, 12.0f32, 8.0f32, 6.0f32] 24.0f32 } +-- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32] 1.0f32 } output { [24.0f32, 12.0f32, 8.0f32, 6.0f32] 24.0f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c diff --git a/tests/ad/reduce1.fut b/tests/ad/reduce1.fut index 85de4de3ab..3e481b7e62 100644 --- a/tests/ad/reduce1.fut +++ b/tests/ad/reduce1.fut @@ -1,7 +1,7 @@ -- Reduce with a fancier operator.
-- == -- entry: rev --- compiled input { [1.0,2.0,3.0] [2.0,3.0,4.0] [3.0,4.0,5.0] [4.0,5.0,6.0] } +-- input { [1.0,2.0,3.0] [2.0,3.0,4.0] [3.0,4.0,5.0] [4.0,5.0,6.0] } -- output { [47.0, 28.0, 32.0] -- [83.0, 44.0, 32.0] -- [47.0, 42.0, 42.0] diff --git a/tests/ad/reduce2.fut b/tests/ad/reduce2.fut index 3f27c48c16..b827e4eb59 100644 --- a/tests/ad/reduce2.fut +++ b/tests/ad/reduce2.fut @@ -2,7 +2,7 @@ -- == -- tags { no_ispc } -- entry: fwd rev --- compiled input { [3f64, 1f64, 5f64] } output { [-1.000000f64, -1.000000f64, -1.000000f64] } +-- input { [3f64, 1f64, 5f64] } output { [-1.000000f64, -1.000000f64, -1.000000f64] } def sumBy 'a (f : a -> f64) (xs : []a) : f64 = map f xs |> f64.sum diff --git a/tests/ad/reduce_by_index0.fut b/tests/ad/reduce_by_index0.fut index 8900cd2784..61f26d1270 100644 --- a/tests/ad/reduce_by_index0.fut +++ b/tests/ad/reduce_by_index0.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp --- compiled input { [0i64,1i64,2i64,3i64] [1f64,2f64,3f64,4f64] } +-- input { [0i64,1i64,2i64,3i64] [1f64,2f64,3f64,4f64] } -- output { [[1f64,0f64,0f64,0f64],[0f64,1f64,0f64,0f64],[0f64,0f64,1f64,0f64],[0f64,0f64,0f64,1f64]] } def f [n] (is: [n]i64) (vs: [n]f64) = hist (+) 0 4 is (map (+2) vs) diff --git a/tests/ad/reducebyindex0.fut b/tests/ad/reducebyindex0.fut index 94e0d50539..65ff4c4d03 100644 --- a/tests/ad/reducebyindex0.fut +++ b/tests/ad/reducebyindex0.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { +-- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] diff --git a/tests/ad/reducebyindex2.fut b/tests/ad/reducebyindex2.fut index ba48276de7..85f7ad4334 100644 --- a/tests/ad/reducebyindex2.fut +++ b/tests/ad/reducebyindex2.fut @@ -1,5 +1,5 @@ -- == --- compiled input { +-- input { -- [0i64,1i64,2i64,3i64,2i64,1i64,0i64,1i64,2i64] -- 
[0f64,1f64,2f64,3f64] -- [2f64,3f64,4f64,5f64,6f64,0f64,8f64,9f64,1f64] diff --git a/tests/ad/reducebyindex3.fut b/tests/ad/reducebyindex3.fut index cce88f1bcb..3bafd9e7b9 100644 --- a/tests/ad/reducebyindex3.fut +++ b/tests/ad/reducebyindex3.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { +-- input { -- [0i64,1i64,2i64,1i64,0i64,1i64,2i64] -- [1f64,2f64,3f64,4f64,5f64,6f64,7f64] } -- output { diff --git a/tests/ad/reducebyindex4.fut b/tests/ad/reducebyindex4.fut index 132a3b7cd8..93f5ba783d 100644 --- a/tests/ad/reducebyindex4.fut +++ b/tests/ad/reducebyindex4.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { +-- input { -- [ 0i64, 1i64, 2i64, 1i64, 0i64, 1i64, 2i64, 1i64, 0i64] -- [ 1f32, 2f32, 3f32, 4f32, 5f32, 6f32, 7f32, 8f32, 9f32] -- [10f32,11f32,12f32,13f32,14f32,15f32,16f32,17f32,18f32] } diff --git a/tests/ad/reducebyindex6.fut b/tests/ad/reducebyindex6.fut index 5d0a577e87..22c38ea8b5 100644 --- a/tests/ad/reducebyindex6.fut +++ b/tests/ad/reducebyindex6.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { +-- input { -- [1i64,-3i64,1i64,5i64,1i64,-3i64,1i64,5i64] -- [1f32, 2f32,3f32,4f32,5f32, 6f32,7f32,8f32] } -- output { diff --git a/tests/ad/reducebyindexadd0.fut b/tests/ad/reducebyindexadd0.fut index a0dfe42957..b936f86d7a 100644 --- a/tests/ad/reducebyindexadd0.fut +++ b/tests/ad/reducebyindexadd0.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [5f32,0f32,0f32,0f32,0f32] } +-- input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [5f32,0f32,0f32,0f32,0f32] } -- checks original dst is used def red_add [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = diff --git a/tests/ad/reducebyindexadd1.fut b/tests/ad/reducebyindexadd1.fut index 4db6865587..a9f55b6925 100644 --- a/tests/ad/reducebyindexadd1.fut +++ b/tests/ad/reducebyindexadd1.fut @@ -1,6 +1,6 @@ -- == -- entry: 
main --- compiled input { +-- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] } diff --git a/tests/ad/reducebyindexadd2.fut b/tests/ad/reducebyindexadd2.fut index 0030a231f8..8320393cac 100644 --- a/tests/ad/reducebyindexadd2.fut +++ b/tests/ad/reducebyindexadd2.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { +-- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] @@ -9,7 +9,7 @@ -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] -- [5f32,6f32,3f32,5f32,6f32,3f32,1f32,1f32,5f32,6f32,2f32,5f32,2f32,4f32,4f32,2f32,0f32,0f32] -- [4f32,14f32,13f32,13f32,29f32,30f32,9f32,0f32] } --- compiled input { +-- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] diff --git a/tests/ad/reducebyindexadd3.fut b/tests/ad/reducebyindexadd3.fut index 0e12d0422f..36fdbae707 100644 --- a/tests/ad/reducebyindexadd3.fut +++ b/tests/ad/reducebyindexadd3.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { +-- input { -- [0i64,0i64,0i64,1i64,1i64,2i64,2i64,2i64,2i64] -- [[1f32,2f32],[0f32,4f32],[5f32,0f32],[9f32,0f32]] -- [[1f32,3f32],[2f32,4f32],[18f32,5f32],[6f32,0f32],[7f32,9f32],[0f32,14f32],[11f32,0f32],[0f32,16f32],[13f32,17f32]] diff --git a/tests/ad/reducebyindexadd4.fut b/tests/ad/reducebyindexadd4.fut index 1146ced5fa..704663dba5 100644 --- a/tests/ad/reducebyindexadd4.fut +++ b/tests/ad/reducebyindexadd4.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { +-- input { 
-- [0i64,0i64,0i64,1i64,1i64,1i64,1i64] -- [[[1f32,2f32],[0f32,4f32]],[[5f32,0f32],[9f32,0f32]]] -- [[[1f32,3f32],[6f32,0f32]],[[2f32,4f32],[7f32,9f32]],[[18f32,5f32],[19f32,20f32]], diff --git a/tests/ad/reducebyindexminmax0.fut b/tests/ad/reducebyindexminmax0.fut index cb33922641..4510e195d8 100644 --- a/tests/ad/reducebyindexminmax0.fut +++ b/tests/ad/reducebyindexminmax0.fut @@ -1,13 +1,13 @@ -- == -- entry: rev --- compiled input { [0i64, 1i64, 2i64, 3i64, 4i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0f32,0f32,0f32,0f32,0f32] } --- compiled input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0f32,0f32,0f32,0f32,1f32] } --- compiled input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 3.0f32] } output { [0f32,0f32,0f32,1f32,0f32] } +-- input { [0i64, 1i64, 2i64, 3i64, 4i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0f32,0f32,0f32,0f32,0f32] } +-- input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0f32,0f32,0f32,0f32,1f32] } +-- input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 3.0f32] } output { [0f32,0f32,0f32,1f32,0f32] } -- == -- entry: revp --- compiled input { [0i64, 1i64, 2i64, 3i64, 4i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0i64,0i64,0i64,0i64,0i64] } --- compiled input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0i64,0i64,0i64,0i64,0i64] } +-- input { [0i64, 1i64, 2i64, 3i64, 4i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0i64,0i64,0i64,0i64,0i64] } +-- input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0i64,0i64,0i64,0i64,0i64] } def red_max [n] (is: [n]i64, vs: [n]f32) = reduce_by_index (replicate 5 0) f32.max f32.lowest is vs diff --git a/tests/ad/reducebyindexminmax1.fut b/tests/ad/reducebyindexminmax1.fut index 7e88d8d55a..04533f6907 100644 --- a/tests/ad/reducebyindexminmax1.fut +++ 
b/tests/ad/reducebyindexminmax1.fut @@ -1,11 +1,11 @@ -- == -- entry: rev --- compiled input { [0f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } --- compiled input { [1f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } --- compiled input { [2f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } --- compiled input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,0f32] } output { [1f32,0f32,0f32,0f32,0f32] } --- compiled input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [1f32,0f32,0f32,0f32,0f32] } --- compiled input { [0f32,1f32,2f32,3f32,4f32] [1i64,2i64,3i64,2i64,1i64] [1f32,2f32,3f32,4f32,5f32] } output { [1f32,0f32,0f32,0f32,0f32] } +-- input { [0f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } +-- input { [1f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } +-- input { [2f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } +-- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,0f32] } output { [1f32,0f32,0f32,0f32,0f32] } +-- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [1f32,0f32,0f32,0f32,0f32] } +-- input { [0f32,1f32,2f32,3f32,4f32] [1i64,2i64,3i64,2i64,1i64] [1f32,2f32,3f32,4f32,5f32] } output { [1f32,0f32,0f32,0f32,0f32] } def red_max [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = reduce_by_index (copy dst) f32.max f32.lowest is vs diff --git a/tests/ad/reducebyindexminmax10.fut b/tests/ad/reducebyindexminmax10.fut index a1d72cb0e8..6f5a66d87b 100644 --- a/tests/ad/reducebyindexminmax10.fut 
+++ b/tests/ad/reducebyindexminmax10.fut @@ -1,5 +1,5 @@ -- == --- compiled input { +-- input { -- [0i64,1i64,0i64,1i64] -- [[1f32,2f32],[3f32,4f32]] -- [[1f32,0f32],[5f32,2f32],[-2f32,3f32],[4f32,6f32]] diff --git a/tests/ad/reducebyindexminmax2.fut b/tests/ad/reducebyindexminmax2.fut index 12d4a40fc0..6d78bcfbeb 100644 --- a/tests/ad/reducebyindexminmax2.fut +++ b/tests/ad/reducebyindexminmax2.fut @@ -1,8 +1,8 @@ -- == -- entry: rev --- compiled input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [4f32,3f32,2f32,1f32,0f32] } output { [0f32,0f32,0f32,0f32,0f32] } --- compiled input { [4f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [5f32,4f32,3f32,2f32,1f32] } output { [1f32,0f32,0f32,0f32,0f32] } --- compiled input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [5f32,4f32,3f32,2f32,1f32] } output { [0f32,0f32,0f32,0f32,0f32] } +-- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [4f32,3f32,2f32,1f32,0f32] } output { [0f32,0f32,0f32,0f32,0f32] } +-- input { [4f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [5f32,4f32,3f32,2f32,1f32] } output { [1f32,0f32,0f32,0f32,0f32] } +-- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [5f32,4f32,3f32,2f32,1f32] } output { [0f32,0f32,0f32,0f32,0f32] } def red_max [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = reduce_by_index (copy dst) f32.max f32.lowest is vs diff --git a/tests/ad/reducebyindexminmax3.fut b/tests/ad/reducebyindexminmax3.fut index 0c22cef70a..620bd8e162 100644 --- a/tests/ad/reducebyindexminmax3.fut +++ b/tests/ad/reducebyindexminmax3.fut @@ -1,10 +1,10 @@ -- == -- entry: rev --- compiled input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } +-- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 5f32 } --- compiled input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } +-- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 10f32 } --- compiled 
input { [5f32,1f32,2f32] [0i64,1i64,0i64] [10f32,30f32,2f32] 3f32 } +-- input { [5f32,1f32,2f32] [0i64,1i64,0i64] [10f32,30f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 10f32 } def red_max [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32, c: f32) = let red = reduce_by_index (copy dst) f32.max f32.lowest is vs diff --git a/tests/ad/reducebyindexminmax4.fut b/tests/ad/reducebyindexminmax4.fut index 1a9c2f9be4..d999be72ae 100644 --- a/tests/ad/reducebyindexminmax4.fut +++ b/tests/ad/reducebyindexminmax4.fut @@ -1,10 +1,10 @@ -- == -- entry: rev --- compiled input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } +-- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 5f32 } --- compiled input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } +-- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 10f32 } --- compiled input { [5f32,1f32,2f32] [0i64,1i64,0i64] [10f32,30f32,2f32] 3f32 } +-- input { [5f32,1f32,2f32] [0i64,1i64,0i64] [10f32,30f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 10f32 } def red_max [n][m] (vs: [n]f32) (is: [n]i64) (dst: [m]f32, c: f32) = let red = reduce_by_index (copy dst) f32.max f32.lowest is vs diff --git a/tests/ad/reducebyindexminmax5.fut b/tests/ad/reducebyindexminmax5.fut index 4c21bbc988..d8cd922ccb 100644 --- a/tests/ad/reducebyindexminmax5.fut +++ b/tests/ad/reducebyindexminmax5.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [4f32,1f32,2f32] [0i64,0i64,0i64] [5f32,1f32,2f32]} +-- input { [4f32,1f32,2f32] [0i64,0i64,0i64] [5f32,1f32,2f32]} -- output { [13f32,5f32,5f32] } def red_max [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = let red = reduce_by_index (copy dst) f32.max f32.lowest is vs diff --git a/tests/ad/reducebyindexminmax6.fut b/tests/ad/reducebyindexminmax6.fut index 2bba0eaa64..16073113f0 100644 --- a/tests/ad/reducebyindexminmax6.fut +++ b/tests/ad/reducebyindexminmax6.fut @@ -1,7 +1,7 @@ -- == -- entry: 
rev --- compiled input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [4f32,0f32,0f32,0f32,0f32] } --- compiled input { [10f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [21f32,0f32,0f32,0f32,0f32] } +-- input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [4f32,0f32,0f32,0f32,0f32] } +-- input { [10f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [21f32,0f32,0f32,0f32,0f32] } -- checks original dst is used def red_max [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = diff --git a/tests/ad/reducebyindexminmax9.fut b/tests/ad/reducebyindexminmax9.fut index fce732bbab..b443138f89 100644 --- a/tests/ad/reducebyindexminmax9.fut +++ b/tests/ad/reducebyindexminmax9.fut @@ -1,5 +1,5 @@ -- == --- compiled input { +-- input { -- [0i64,1i64,0i64,1i64] -- [[[1f32,2f32],[3f32,4f32]],[[5f32,6f32],[7f32,8f32]]] -- [ [[1f32,0f32],[5f32,2f32]], [[7f32,4f32],[9f32,7f32]], [[-2f32,3f32],[4f32,6f32]], [[1f32,2f32],[5f32,9f32]] ] diff --git a/tests/ad/reducebyindexmul0.fut b/tests/ad/reducebyindexmul0.fut index 9e5810882c..78458f413f 100644 --- a/tests/ad/reducebyindexmul0.fut +++ b/tests/ad/reducebyindexmul0.fut @@ -6,7 +6,7 @@ -- == -- entry: main --- compiled input { +-- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] } diff --git a/tests/ad/reducebyindexmul1.fut b/tests/ad/reducebyindexmul1.fut index 31aa712e6b..6272af58ea 100644 --- a/tests/ad/reducebyindexmul1.fut +++ b/tests/ad/reducebyindexmul1.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { +-- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- 
[1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] @@ -9,7 +9,7 @@ -- [2f32,120f32,126f32,0f32,0f32,0f32,7f32,8f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,1f32,2f32,0f32,0f32,48f32,0f32,80f32,144f32,0f32,60f32,0f32,0f32] -- [2f32,120f32,0f32,0f32,0f32,0f32,9f32,0f32] } --- compiled input { +-- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] diff --git a/tests/ad/reducebyindexmul2.fut b/tests/ad/reducebyindexmul2.fut index 0d1278b6e3..27e8fa29be 100644 --- a/tests/ad/reducebyindexmul2.fut +++ b/tests/ad/reducebyindexmul2.fut @@ -1,9 +1,9 @@ -- == -- entry: rev --- compiled input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [244f32,0f32,0f32,0f32,0f32] } --- compiled input { [10f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [0f32,2f32,3f32,4f32,5f32] } output { [20f32,0f32,0f32,0f32,0f32] } --- compiled input { [0f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [120f32,0f32,0f32,0f32,0f32] } --- compiled input { [0f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,0f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } +-- input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [244f32,0f32,0f32,0f32,0f32] } +-- input { [10f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [0f32,2f32,3f32,4f32,5f32] } output { [20f32,0f32,0f32,0f32,0f32] } +-- input { [0f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [120f32,0f32,0f32,0f32,0f32] } +-- input { [0f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,0f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } -- checks original dst is used 
def red_mul [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = diff --git a/tests/ad/reducebyindexmul3.fut b/tests/ad/reducebyindexmul3.fut index 5ec35e8aef..7b61f1e6f4 100644 --- a/tests/ad/reducebyindexmul3.fut +++ b/tests/ad/reducebyindexmul3.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { +-- input { -- [0i64,0i64,0i64,1i64,1i64,2i64,2i64,2i64,2i64] -- [[1f32,2f32],[0f32,4f32],[5f32,0f32],[9f32,0f32]] -- [[1f32,3f32],[2f32,4f32],[18f32,5f32],[6f32,0f32],[7f32,9f32],[0f32,14f32],[11f32,0f32],[0f32,16f32],[13f32,17f32]] diff --git a/tests/ad/reducebyindexmul4.fut b/tests/ad/reducebyindexmul4.fut index 6ef9741972..4821a468e3 100644 --- a/tests/ad/reducebyindexmul4.fut +++ b/tests/ad/reducebyindexmul4.fut @@ -1,6 +1,6 @@ -- == -- entry: main --- compiled input { +-- input { -- [0i64,0i64,0i64,1i64,1i64,1i64,1i64] -- [[[1f32,2f32],[0f32,4f32]],[[5f32,0f32],[9f32,0f32]]] -- [[[1f32,3f32],[6f32,0f32]],[[2f32,4f32],[7f32,9f32]],[[18f32,5f32],[19f32,20f32]], diff --git a/tests/ad/reducebyindexvecmin0.fut b/tests/ad/reducebyindexvecmin0.fut index b03a95e682..655e29d021 100644 --- a/tests/ad/reducebyindexvecmin0.fut +++ b/tests/ad/reducebyindexvecmin0.fut @@ -1,6 +1,6 @@ -- == -- entry: vecmin --- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] -- [[8i32, 5i32, -2i32, 4i32, 6i32], -- [12i32, 8i32, 7i32, 2i32, 6i32], -- [3i32, 9i32, -2i32, 11i32, 1i32], diff --git a/tests/ad/reducebyindexvecmul0.fut b/tests/ad/reducebyindexvecmul0.fut index f93a779e64..3d39085fe3 100644 --- a/tests/ad/reducebyindexvecmul0.fut +++ b/tests/ad/reducebyindexvecmul0.fut @@ -1,6 +1,6 @@ -- == -- entry: vecmul --- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] -- [[8i32, 5i32, -2i32, 4i32, 6i32], -- [12i32, 8i32, 7i32, 2i32, 6i32], -- [3i32, 9i32, -2i32, 11i32, 1i32], diff --git 
a/tests/ad/reducemul0.fut b/tests/ad/reducemul0.fut index 42b5e08b7e..2a0dd243a8 100644 --- a/tests/ad/reducemul0.fut +++ b/tests/ad/reducemul0.fut @@ -1,6 +1,6 @@ -- == -- entry: rev fwd --- compiled input { [0.0f32, 2.0f32, 0.0f32, 4.0f32] } output { [0.0f32, 0.0f32, 0.0f32, 0.0f32] } +-- input { [0.0f32, 2.0f32, 0.0f32, 4.0f32] } output { [0.0f32, 0.0f32, 0.0f32, 0.0f32] } def red_mult [n] (xs: [n]f32) : f32 = reduce (*) 1 xs diff --git a/tests/ad/reducemul1.fut b/tests/ad/reducemul1.fut index a013b018a6..c4626ebba6 100644 --- a/tests/ad/reducemul1.fut +++ b/tests/ad/reducemul1.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [1f32, 0f32, 3f32, 4f32] 3.0f32 } output { [0f32, 36f32, 0f32, 0f32] 0f32 } +-- input { [1f32, 0f32, 3f32, 4f32] 3.0f32 } output { [0f32, 36f32, 0f32, 0f32] 0f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c diff --git a/tests/ad/reducemul2.fut b/tests/ad/reducemul2.fut index 25b1202013..34e2dc17e0 100644 --- a/tests/ad/reducemul2.fut +++ b/tests/ad/reducemul2.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [1f32, 0f32, 3f32, 0f32] 3.0f32 } output { [0f32, 0f32, 0f32, 0f32] 0f32 } +-- input { [1f32, 0f32, 3f32, 0f32] 3.0f32 } output { [0f32, 0f32, 0f32, 0f32] 0f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c diff --git a/tests/ad/reducemul3.fut b/tests/ad/reducemul3.fut index 4671268d10..9005763b45 100644 --- a/tests/ad/reducemul3.fut +++ b/tests/ad/reducemul3.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [1f32, 2f32, 3f32, 4f32] 3.0f32 } output { [72f32, 36f32, 24f32, 18f32] 24f32 } +-- input { [1f32, 2f32, 3f32, 4f32] 3.0f32 } output { [72f32, 36f32, 24f32, 18f32] 24f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c diff --git a/tests/ad/reducemul4.fut b/tests/ad/reducemul4.fut index 55b41b04f1..0448d69a7f 100644 --- a/tests/ad/reducemul4.fut +++ b/tests/ad/reducemul4.fut @@ -1,6 +1,6 @@ -- == -- entry: fwd rev --- compiled input { 
[1f32, 2f32, 3f32, 4f32] } output { [[48f32, 12f32, 8f32, 6f32], [48f32, 48f32, 16f32, 12f32], [72f32, 36f32, 48f32, 18f32], [96f32, 48f32, 32f32, 48f32]] } +-- input { [1f32, 2f32, 3f32, 4f32] } output { [[48f32, 12f32, 8f32, 6f32], [48f32, 48f32, 16f32, 12f32], [72f32, 36f32, 48f32, 18f32], [96f32, 48f32, 32f32, 48f32]] } def fun [n] (as: [n]f32) = let x = reduce (*) 1 as diff --git a/tests/ad/reducevec0.fut b/tests/ad/reducevec0.fut index a0c72aa307..99269c0c35 100644 --- a/tests/ad/reducevec0.fut +++ b/tests/ad/reducevec0.fut @@ -1,6 +1,6 @@ -- == -- entry: rev fwd --- compiled input { +-- input { -- [[[0f32,1f32],[2f32,3f32]], -- [[5f32,1f32],[3f32,0f32]], -- [[0f32,1f32],[4f32,4f32]]] } diff --git a/tests/ad/reducevecmul0.fut b/tests/ad/reducevecmul0.fut index 76f532d1d2..784a354536 100644 --- a/tests/ad/reducevecmul0.fut +++ b/tests/ad/reducevecmul0.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [[0.0f32, 2.0f32, 0.0f32, 4.0f32], [4.0f32, 2.0f32, 0.0f32, 0.0f32]] } output { [[4.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32, 0.000000f32, 0.000000f32]] } +-- input { [[0.0f32, 2.0f32, 0.0f32, 4.0f32], [4.0f32, 2.0f32, 0.0f32, 0.0f32]] } output { [[4.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32, 0.000000f32, 0.000000f32]] } let red_mult [m][n] (xs: [m][n]f32) : [n]f32 = reduce (map2 (*)) (replicate n 1) xs diff --git a/tests/ad/reducevecmul1.fut b/tests/ad/reducevecmul1.fut index 11951b3f4c..bd86aa82aa 100644 --- a/tests/ad/reducevecmul1.fut +++ b/tests/ad/reducevecmul1.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [[1f32, 0f32, 3f32, 4f32], [2f32,2f32,2f32,2f32]] 3.0f32 } output { [[0f32, 0f32, 0f32, 6f32], [0f32, 0f32, 0f32,12f32]] 8f32 } +-- input { [[1f32, 0f32, 3f32, 4f32], [2f32,2f32,2f32,2f32]] 3.0f32 } output { [[0f32, 0f32, 0f32, 6f32], [0f32, 0f32, 0f32,12f32]] 8f32 } def red_mult [m][n] (xs: [m][n]f32, c: f32) = reduce (map2 (*)) (replicate n 1) xs |> map (*c) 
diff --git a/tests/ad/reducevecmul2.fut b/tests/ad/reducevecmul2.fut index ed3f41784a..caca2daf0e 100644 --- a/tests/ad/reducevecmul2.fut +++ b/tests/ad/reducevecmul2.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [[1f32, 0f32, 3f32, 0f32], [1f32,2f32,3f32,4f32]] 3.0f32 } output { [[0f32, 0f32, 9f32, 0f32], [0f32, 0f32, 9f32, 0f32]] 9f32 } +-- input { [[1f32, 0f32, 3f32, 0f32], [1f32,2f32,3f32,4f32]] 3.0f32 } output { [[0f32, 0f32, 9f32, 0f32], [0f32, 0f32, 9f32, 0f32]] 9f32 } def red_mult [m][n] (xs: [m][n]f32, c: f32) = reduce (map2 (*)) (replicate n 1) xs |> map (*c) diff --git a/tests/ad/reducevecmul3.fut b/tests/ad/reducevecmul3.fut index 1802c90850..58c458dd02 100644 --- a/tests/ad/reducevecmul3.fut +++ b/tests/ad/reducevecmul3.fut @@ -1,6 +1,6 @@ -- == -- entry: rev --- compiled input { [[1f32, 2f32, 3f32, 4f32], [1f32, 1f32, 1f32, 1f32]] 3.0f32 } output { [[0f32, 3f32, 0f32, 0f32],[0f32, 6f32, 0f32, 0f32]] 2f32 } +-- input { [[1f32, 2f32, 3f32, 4f32], [1f32, 1f32, 1f32, 1f32]] 3.0f32 } output { [[0f32, 3f32, 0f32, 0f32],[0f32, 6f32, 0f32, 0f32]] 2f32 } def red_mult [m][n] (xs: [m][n]f32, c: f32) = reduce (map2 (*)) (replicate n 1) xs |> map (*c) diff --git a/tests/ad/replicate0.fut b/tests/ad/replicate0.fut index 5096fa557e..40f402b1da 100644 --- a/tests/ad/replicate0.fut +++ b/tests/ad/replicate0.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp --- compiled input { 3i64 2 } +-- input { 3i64 2 } -- output { [1,1,1] } entry f_jvp n x : []i32 = @@ -8,7 +8,7 @@ entry f_jvp n x : []i32 = -- == -- entry: f_vjp --- compiled input { 3i64 2i64 } +-- input { 3i64 2i64 } -- output { 3i64 } entry f_vjp n x = diff --git a/tests/ad/replicate1.fut b/tests/ad/replicate1.fut index 7a81462968..8c27bdc872 100644 --- a/tests/ad/replicate1.fut +++ b/tests/ad/replicate1.fut @@ -3,7 +3,7 @@ -- == -- entry: f_jvp --- compiled input { 3i64 2i64 } +-- input { 3i64 2i64 } -- output { [0i64,0i64,0i64] } entry f_jvp n x = @@ -11,7 +11,7 @@ entry f_jvp n x = -- == -- entry: 
f_vjp --- compiled input { 3i64 2i64 } +-- input { 3i64 2i64 } -- output { 0i64 } entry f_vjp n x = diff --git a/tests/ad/replicate2.fut b/tests/ad/replicate2.fut index f47d5a5000..719a6ce7b3 100644 --- a/tests/ad/replicate2.fut +++ b/tests/ad/replicate2.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp --- compiled input { 2i64 3i64 2 } +-- input { 2i64 3i64 2 } -- output { [[1,1,1],[1,1,1]] } entry f_jvp n m x : [][]i32 = @@ -8,7 +8,7 @@ entry f_jvp n m x : [][]i32 = -- == -- entry: f_vjp --- compiled input { 2i64 3i64 2i64 } +-- input { 2i64 3i64 2i64 } -- output { 6i64 } entry f_vjp n m x = diff --git a/tests/ad/reshape0.fut b/tests/ad/reshape0.fut index 3337989a12..bb8f99f7e6 100644 --- a/tests/ad/reshape0.fut +++ b/tests/ad/reshape0.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp --- compiled input { 2i64 2i64 [1,2,3,4] } +-- input { 2i64 2i64 [1,2,3,4] } -- output { [[1,2],[3,4]] } entry f_jvp n m (xs: [n*m]i32) = @@ -8,7 +8,7 @@ entry f_jvp n m (xs: [n*m]i32) = -- == -- entry: f_vjp --- compiled input { 2i64 2i64 [1,2,3,4] } +-- input { 2i64 2i64 [1,2,3,4] } -- output { [1,2,3,4] } entry f_vjp n m (xs: [n*m]i32) = diff --git a/tests/ad/rev_const.fut b/tests/ad/rev_const.fut index ed8f540694..5cef1969eb 100644 --- a/tests/ad/rev_const.fut +++ b/tests/ad/rev_const.fut @@ -1,6 +1,6 @@ -- What happens if a result is constant? -- == --- compiled input { 1f32 2f32 } output { 1f32 1f32 } +-- input { 1f32 2f32 } output { 1f32 1f32 } def main (x: f32) (y: f32) = vjp (\(x',y') -> (x' + y', 0)) (x,y) (1, 0) diff --git a/tests/ad/rev_unused.fut b/tests/ad/rev_unused.fut index fed1c20ad8..e233fc0159 100644 --- a/tests/ad/rev_unused.fut +++ b/tests/ad/rev_unused.fut @@ -1,6 +1,6 @@ -- What happens if not all the parameters are used? 
-- == --- compiled input { 1f32 2f32 } output { 1f32 0f32 } +-- input { 1f32 2f32 } output { 1f32 0f32 } def main (x: f32) (y: f32) = vjp (\(x',_) -> x' + 2) (x,y) 1 diff --git a/tests/ad/rotate0.fut b/tests/ad/rotate0.fut index c8b1160c2b..811879d840 100644 --- a/tests/ad/rotate0.fut +++ b/tests/ad/rotate0.fut @@ -1,6 +1,6 @@ -- == -- entry: f_jvp --- compiled input { 1i64 [1,2,3,4] } +-- input { 1i64 [1,2,3,4] } -- output { [2,3,4,1] } entry f_jvp k (xs: []i32) = @@ -8,7 +8,7 @@ entry f_jvp k (xs: []i32) = -- == -- entry: f_vjp --- compiled input { 1i64 [1,2,3,4] } +-- input { 1i64 [1,2,3,4] } -- output { [1,2,3,4] } entry f_vjp k (xs: []i32) = diff --git a/tests/ad/scan0.fut b/tests/ad/scan0.fut index 1e7f560ddf..09d06fc0ee 100644 --- a/tests/ad/scan0.fut +++ b/tests/ad/scan0.fut @@ -2,7 +2,7 @@ -- generic case -- == -- entry: fwd_J rev_J --- compiled input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } +-- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [2.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], -- [6.0f32, 3.0f32, 2.0f32, 0.0f32, 0.0f32], diff --git a/tests/ad/scan1.fut b/tests/ad/scan1.fut index 0fe84bb3fc..ec76faa5b5 100644 --- a/tests/ad/scan1.fut +++ b/tests/ad/scan1.fut @@ -2,7 +2,7 @@ -- addition special case -- == -- entry: fwd_J rev_J --- compiled input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } +-- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [1.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], -- [1.0f32, 1.0f32, 1.0f32, 0.0f32, 0.0f32], diff --git a/tests/ad/scan2.fut b/tests/ad/scan2.fut index 24f068a5ed..90be756329 100644 --- a/tests/ad/scan2.fut +++ b/tests/ad/scan2.fut @@ -2,7 +2,7 @@ -- special cases: vectorised and addition -- == -- entry: fwd_J rev_J --- compiled input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } +-- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { 
[[[1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32]]] } def primal [n][k] (a: [n][k]f32) = diff --git a/tests/ad/scan3.fut b/tests/ad/scan3.fut index 11354d4731..525a6e3839 100644 --- a/tests/ad/scan3.fut +++ b/tests/ad/scan3.fut @@ -2,7 +2,7 @@ -- MatrixMul case -- == -- entry: fwd_J rev_J --- compiled input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } +-- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } -- output { -- [[[[1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], diff --git a/tests/ad/scan4.fut b/tests/ad/scan4.fut index 5cddd64c26..4d15a3775a 100644 --- a/tests/ad/scan4.fut +++ b/tests/ad/scan4.fut @@ -2,7 +2,7 @@ -- ZeroQuadrant case -- == -- entry: fwd_J rev_J --- compiled input { [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 3.0f32, 5.0f32], [3.0f32, 4.0f32, 2.0f32], [4.0f32, 2.0f32, 1.0f32]] } +-- input { [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 3.0f32, 5.0f32], [3.0f32, 4.0f32, 2.0f32], [4.0f32, 2.0f32, 1.0f32]] } -- output { -- [[[1f32, 1f32, 1f32], [0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32], [0f32, 0f32, 0f32]], diff --git a/tests/ad/scan5.fut 
b/tests/ad/scan5.fut index dc98bd5ea3..d80838474d 100644 --- a/tests/ad/scan5.fut +++ b/tests/ad/scan5.fut @@ -2,7 +2,7 @@ -- Vectorised special case + generic case -- == -- entry: fwd_J rev_J --- compiled input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } +-- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { -- [[[1f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], -- [[2f32, 2f32], [1f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], diff --git a/tests/ad/scan6.fut b/tests/ad/scan6.fut index 34a2cfc822..9c963db03a 100644 --- a/tests/ad/scan6.fut +++ b/tests/ad/scan6.fut @@ -2,7 +2,7 @@ -- MatrixMul case -- == -- entry: fwd_J rev_J --- compiled input { [[1f32, 2f32], [4f32, 3f32], [3f32, 4f32], [4f32, 2f32]] } +-- input { [[1f32, 2f32], [4f32, 3f32], [3f32, 4f32], [4f32, 2f32]] } -- output { -- [[[[1f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], -- [[0f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]]], @@ -35,7 +35,7 @@ entry rev_J [n] (input: [n][2]f32) = -- == -- entry: fwd_J2 rev_J2 --- compiled input { [[1f32,2f32,3f32,4f32,5f32,6f32],[6f32,5f32,4f32,3f32,2f32,1f32],[4f32,5f32,6f32,1f32,2f32,3f32],[3f32,2f32,1f32,6f32,5f32,4f32]] } +-- no_oclgrind input { [[1f32,2f32,3f32,4f32,5f32,6f32],[6f32,5f32,4f32,3f32,2f32,1f32],[4f32,5f32,6f32,1f32,2f32,3f32],[3f32,2f32,1f32,6f32,5f32,4f32]] } -- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 
0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[4f32, 3f32, 0f32, 0f32, 0f32, 0f32], [1f32, 0f32, 1f32, 2f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[2f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 1f32, 0f32, 0f32, 1f32, 2f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 4f32, 0f32, 3f32, 0f32], [0f32, 0f32, 3f32, 5f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 4f32, 0f32, 3f32], [0f32, 0f32, 4f32, 6f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 2f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 3f32, 5f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 2f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 4f32, 6f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[26f32, 19f32, 0f32, 0f32, 0f32, 0f32], [6f32, 1f32, 6f32, 12f32, 1f32, 2f32], [1f32, 0f32, 16f32, 9f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[14f32, 9f32, 0f32, 0f32, 0f32, 0f32], [2f32, 3f32, 2f32, 4f32, 3f32, 6f32], [0f32, 1f32, 0f32, 0f32, 16f32, 9f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 26f32, 0f32, 19f32, 0f32], [0f32, 0f32, 18f32, 30f32, 3f32, 5f32], [0f32, 0f32, 27f32, 11f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 26f32, 0f32, 19f32], [0f32, 0f32, 24f32, 36f32, 4f32, 6f32], [0f32, 0f32, 34f32, 14f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 14f32, 0f32, 9f32, 0f32], [0f32, 0f32, 6f32, 10f32, 9f32, 15f32], [0f32, 0f32, 
0f32, 0f32, 27f32, 11f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 14f32, 0f32, 9f32], [0f32, 0f32, 8f32, 12f32, 12f32, 18f32], [0f32, 0f32, 0f32, 0f32, 34f32, 14f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[110f32, 73f32, 0f32, 0f32, 0f32, 0f32], [18f32, 19f32, 18f32, 36f32, 19f32, 38f32], [1f32, 6f32, 16f32, 9f32, 96f32, 54f32], [1f32, 0f32, 109f32, 64f32, 0f32, 0f32]], [[186f32, 131f32, 0f32, 0f32, 0f32, 0f32], [38f32, 17f32, 38f32, 76f32, 17f32, 34f32], [5f32, 4f32, 80f32, 45f32, 64f32, 36f32], [0f32, 1f32, 0f32, 0f32, 109f32, 64f32]], [[0f32, 0f32, 110f32, 0f32, 73f32, 0f32], [0f32, 0f32, 54f32, 90f32, 57f32, 95f32], [0f32, 0f32, 27f32, 11f32, 162f32, 66f32], [0f32, 0f32, 173f32, 87f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 110f32, 0f32, 73f32], [0f32, 0f32, 72f32, 108f32, 76f32, 114f32], [0f32, 0f32, 34f32, 14f32, 204f32, 84f32], [0f32, 0f32, 218f32, 110f32, 0f32, 0f32]], [[0f32, 0f32, 186f32, 0f32, 131f32, 0f32], [0f32, 0f32, 114f32, 190f32, 51f32, 85f32], [0f32, 0f32, 135f32, 55f32, 108f32, 44f32], [0f32, 0f32, 0f32, 0f32, 173f32, 87f32]], [[0f32, 0f32, 0f32, 186f32, 0f32, 131f32], [0f32, 0f32, 152f32, 228f32, 68f32, 102f32], [0f32, 0f32, 170f32, 70f32, 136f32, 56f32], [0f32, 0f32, 0f32, 0f32, 218f32, 110f32]]]] } def mm2by2 (a1, b1, c1, d1) (a2, b2, c2, d2) : (f32,f32,f32,f32) = diff --git a/tests/ad/scan7.fut b/tests/ad/scan7.fut index a87a13cb55..3ebf9a2616 100644 --- a/tests/ad/scan7.fut +++ b/tests/ad/scan7.fut @@ -2,7 +2,7 @@ -- vectorised special case, generic case -- == -- entry: fwd_J rev_J --- compiled input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], +-- input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], -- [[3f32,4f32], [4f32,5f32]], [[4f32,5f32], [2f32,3f32]]] } -- output { --[[[[[[1f32, 0f32], [0f32, 0f32]], [[4f32, 0f32], [0f32, 0f32]], @@ -59,7 +59,7 @@ entry rev_J [n][m][k] (input: [n][m][k]f32) = -- == -- entry: test --- compiled input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], 
+-- input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], -- [[3f32,4f32], [4f32,5f32]], [[4f32,5f32], [2f32,3f32]]] -- [[[[[[1f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 1f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [1f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]]], [[[[[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [1f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]]], [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [1f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 0f32]]]]], [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [1f32, 0f32]]], [[[0f32, 0f32], 
[0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 1f32]]]]]] -- } diff --git a/tests/ad/scan8.fut b/tests/ad/scan8.fut index 47a7e75660..1eb490e708 100644 --- a/tests/ad/scan8.fut +++ b/tests/ad/scan8.fut @@ -1,7 +1,7 @@ -- Scan with 3x3 matrix multiplication. -- == -- entry: fwd rev --- compiled input { [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], +-- input { [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], -- [9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32], -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], -- [9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32]] } diff --git a/tests/ad/scan9.fut b/tests/ad/scan9.fut index ea1da3de8a..771f00f367 100644 --- a/tests/ad/scan9.fut +++ b/tests/ad/scan9.fut @@ -1,7 +1,7 @@ -- Scan with 4x4 matrix multiplication. -- == -- entry: fwd rev --- compiled input { +-- no_oclgrind input { -- [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32,10f32,11f32,12f32,13f32,14f32,15f32,16f32], -- [16f32,15f32,14f32,13f32,12f32,11f32,10f32,9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32], -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32,10f32,11f32,12f32,13f32,14f32,15f32,16f32], diff --git a/tests/ad/scatter0.fut b/tests/ad/scatter0.fut index 6dbbed2f98..97e70aa347 100644 --- a/tests/ad/scatter0.fut +++ b/tests/ad/scatter0.fut @@ -1,7 +1,7 @@ -- Simple scatter, differentiating wrt. values. -- == -- entry: fwd rev --- compiled input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64, 2i64, 3i64] [1f64, 2f64, 3f64, 0f64] } +-- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64, 2i64, 3i64] [1f64, 2f64, 3f64, 0f64] } -- output { -- [[1.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], -- [0.000000f64, 1.000000f64, 0.000000f64, 0.000000f64], diff --git a/tests/ad/scatter1.fut b/tests/ad/scatter1.fut index a323e25a6f..9de2b96a53 100644 --- a/tests/ad/scatter1.fut +++ b/tests/ad/scatter1.fut @@ -1,7 +1,7 @@ -- Simple scatter, differentiating wrt. target. 
-- == -- entry: fwd rev --- compiled input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64] [1f64, 2f64] } +-- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64] [1f64, 2f64] } -- output { -- [[0.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], diff --git a/tests/ad/sdf.fut b/tests/ad/sdf.fut index 0417c4f7a9..e8818ca174 100644 --- a/tests/ad/sdf.fut +++ b/tests/ad/sdf.fut @@ -1,9 +1,9 @@ -- Signed Distance Functions, as you would find in a ray marcher. -- == -- entry: jvp_normal vjp_normal --- compiled input { 0i32 0f64 1f64 0f64 } output { 1f64 0f64 0f64 } --- compiled input { 1i32 0f64 1f64 0f64 } output { 0.412393f64 0.907265f64 -0.082479f64 } --- compiled input { 2i32 0f64 1f64 0f64 } output { -0.375775f64 0.903687f64 -0.205287f64 } +-- input { 0i32 0f64 1f64 0f64 } output { 1f64 0f64 0f64 } +-- input { 1i32 0f64 1f64 0f64 } output { 0.412393f64 0.907265f64 -0.082479f64 } +-- input { 2i32 0f64 1f64 0f64 } output { -0.375775f64 0.903687f64 -0.205287f64 } type Vec = {x:f64, y: f64, z: f64} diff --git a/tests/ad/stripmine0.fut b/tests/ad/stripmine0.fut index caac5224a9..271a426304 100644 --- a/tests/ad/stripmine0.fut +++ b/tests/ad/stripmine0.fut @@ -1,11 +1,11 @@ def pow y x = #[stripmine(3)] loop acc = 1 for _i < y do - acc * x + acc * x -- == -- entry: f_jvp f_vjp --- compiled input { 3 4 } output { 48 } --- compiled input { 9 3 } output { 59049 } +-- input { 3 4 } output { 48 } +-- input { 9 3 } output { 59049 } -- compiled input { 1000000 1 } output { 1000000 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1 diff --git a/tests/ad/stripmine1.fut b/tests/ad/stripmine1.fut index e13041c2ca..e31f4cccfb 100644 --- a/tests/ad/stripmine1.fut +++ b/tests/ad/stripmine1.fut @@ -7,12 +7,12 @@ def square [n] (xs: [n]i32) = -- == -- entry: prim --- compiled input { [1,2,3,4,5] } output { [1,4,9,16,25] } +-- input { [1,2,3,4,5] } output { [1,4,9,16,25] } entry prim [n] (xs: [n]i32) = square xs -- == -- 
entry: f_jvp f_vjp --- compiled input { [1,2,3,4,5] } +-- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], -- [0,0,6,0,0], diff --git a/tests/ad/stripmine2.fut b/tests/ad/stripmine2.fut index f82f07b975..1e654969d2 100644 --- a/tests/ad/stripmine2.fut +++ b/tests/ad/stripmine2.fut @@ -5,12 +5,12 @@ def pow_list [n] y (xs :[n]i32) = -- == -- entry: prim --- compiled input { 3 [1,2,3] } output { [1,8,27] } +-- input { 3 [1,2,3] } output { [1,8,27] } entry prim y xs = pow_list y xs -- == -- entry: f_vjp f_jvp --- compiled input { 3 [1,2,3] } +-- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], -- [0,0,27]] diff --git a/tests/ad/stripmine3.fut b/tests/ad/stripmine3.fut index 08135a4498..e57c96578e 100644 --- a/tests/ad/stripmine3.fut +++ b/tests/ad/stripmine3.fut @@ -8,10 +8,10 @@ def test [n] (xs: [n]i32) = -- == -- entry: prim --- compiled input { [1,2,3,4,5] } output { [1,1,1,1,1] } +-- input { [1,2,3,4,5] } output { [1,1,1,1,1] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_vjp --- compiled input { [1,2,3,4,5] } output { [0,0,0,0,0] } +-- input { [1,2,3,4,5] } output { [0,0,0,0,0] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) diff --git a/tests/ad/sum.fut b/tests/ad/sum.fut index e2d3861a86..25e71c4c67 100644 --- a/tests/ad/sum.fut +++ b/tests/ad/sum.fut @@ -1,7 +1,7 @@ -- Simple reduce with summation. 
-- == -- entry: rev fwd --- compiled input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } +-- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } -- output { [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] } def sum [n] (xs: [n]f64) = diff --git a/tests/ad/truedep0.fut b/tests/ad/truedep0.fut index 90149139b5..5465d50a95 100644 --- a/tests/ad/truedep0.fut +++ b/tests/ad/truedep0.fut @@ -4,12 +4,12 @@ def test [n] (xs: [n]i32) = -- == -- entry: prim --- compiled input { [2,2,3,4,5] } output { [2,4,16,256,65536] } +-- input { [2,2,3,4,5] } output { [2,4,16,256,65536] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_jvp f_vjp --- compiled input { [1,2,3,4,5] } +-- input { [1,2,3,4,5] } -- output { [[1,0,0,0,0], -- [2,0,0,0,0], -- [4,0,0,0,0], diff --git a/tests/ad/while0.fut b/tests/ad/while0.fut index 66942e9bbb..90fdcb2908 100644 --- a/tests/ad/while0.fut +++ b/tests/ad/while0.fut @@ -7,13 +7,13 @@ def pow y x = -- == -- entry: prim --- compiled input { 3 4 } output { 64 } --- compiled input { 9 3 } output { 19683 } +-- input { 3 4 } output { 64 } +-- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp f_vjp --- compiled input { 3 4 } output { 48 } --- compiled input { 9 3 } output { 59049 } +-- input { 3 4 } output { 48 } +-- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1 diff --git a/tests/ad/while1.fut b/tests/ad/while1.fut index 66005a0a2f..d02831126c 100644 --- a/tests/ad/while1.fut +++ b/tests/ad/while1.fut @@ -6,13 +6,13 @@ def pow y x = -- == -- entry: prim --- compiled input { 3 4 } output { 64 } --- compiled input { 9 3 } output { 19683 } +-- input { 3 4 } output { 64 } +-- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp f_vjp --- compiled input { 3 4 } output { 48 } --- compiled input { 9 3 } output { 59049 } +-- input { 3 4 } output { 48 } +-- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1