From 7a69912a02f4d85ee94b39c9af4859ff2003d5a9 Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 27 Oct 2021 20:38:47 +0200 Subject: [PATCH 1/3] Implement the vector indexing operation --- .../Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs | 2 +- .../Array/Accelerate/LLVM/CodeGen/Arithmetic.hs | 13 +++++++++++++ .../src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs | 10 +++++----- .../Data/Array/Accelerate/LLVM/CodeGen/Constant.hs | 2 +- .../src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs | 8 +++++--- accelerate-llvm/src/LLVM/AST/Type/Instruction.hs | 7 ++++--- accelerate-llvm/src/LLVM/AST/Type/Operand.hs | 4 ++++ 7 files changed, 33 insertions(+), 13 deletions(-) diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs index 7c414076b..5955334b8 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs @@ -445,7 +445,7 @@ shfl sop tR val delta = go tR val repack :: Int32 -> CodeGen PTX (Operands (Vec m Int32)) repack 0 = return $ ir v' (A.undef (VectorScalarType v')) repack i = do - d <- instr $ ExtractElement (i-1) c + d <- instr $ ExtractElement integralType c (constOp (i-1)) e <- integral integralType d f <- repack (i-1) g <- instr $ InsertElement (i-1) (op v' f) (op integralType e) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs index 48555c94a..47ab72c60 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs @@ -20,6 +20,8 @@ module Data.Array.Accelerate.LLVM.CodeGen.Arithmetic where +import Data.Primitive.Vec + import Data.Array.Accelerate.AST ( PrimMaybe ) import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Representation.Tag @@ -464,6 +466,17 @@ min ty x y | otherwise = do c <- unbool <$> lte ty x y binop (flip Select c) ty x y +-- Vector operators +-- ---------------------- + +vecCreate :: VectorType (Vec n a) -> CodeGen arch (Operands (Vec n a)) +vecCreate = undefined + +vecIndex :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> CodeGen arch (Operands a) +vecIndex tv ti (OP_Vec v) i = do + (OP_Int32 i') <- fromIntegral ti (IntegralNumType TypeInt32) i + instr $ ExtractElement TypeInt32 v i' + -- Logical operators -- ----------------- diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs index ea984d21c..049960857 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs @@ -28,6 +28,7 @@ import LLVM.AST.Type.AddrSpace import LLVM.AST.Type.Instruction import LLVM.AST.Type.Instruction.Volatile import LLVM.AST.Type.Operand +import LLVM.AST.Type.Constant import LLVM.AST.Type.Representation import Data.Array.Accelerate.Representation.Array @@ -205,16 +206,15 @@ store addrspace volatility e p v | SingleScalarType{} <- e = do_ $ Store volatility p v | VectorScalarType s <- e , VectorType n base <- s - , m <- fromIntegral n - = if popCount m == 1 + = if popCount n == 1 then do_ $ Store volatility p v else do p' <- instr' $ PtrCast (PtrPrimType (ScalarPrimType (SingleScalarType base)) addrspace) p -- - let go i - | i >= m = return () + let go i + | i >= n = return () | otherwise = do - x <- instr' $ ExtractElement i v + x <- instr' $ ExtractElement integralType v (constOp n) q <- instr' $ GetElementPtr p' [integral integralType i] _ <- instr' $ Store volatility q x go (i+1) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs index 26c9497ed..962f1d86f 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs @@ -61,7 +61,7 @@ scalar t = ConstantOperand . ScalarConstant t single :: SingleType a -> a -> Operand a single t = scalar (SingleScalarType t) -vector :: VectorType (Vec n a) -> (Vec n a) -> Operand (Vec n a) +vector :: VectorType (Vec n a) -> Vec n a -> Operand (Vec n a) vector t = scalar (VectorScalarType t) num :: NumType a -> a -> Operand a diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs index 77a2f0860..95149b14b 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs @@ -48,7 +48,8 @@ import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop as L import Data.Primitive.Vec import LLVM.AST.Type.Instruction -import LLVM.AST.Type.Operand ( Operand ) +import LLVM.AST.Type.Operand ( Operand(..), constOp) +import LLVM.AST.Type.Constant ( Constant(..), ) import Control.Applicative hiding ( Const ) import Control.Monad @@ -105,7 +106,7 @@ llvmOfOpenExp top env aenv = cvtE top llvmOfOpenExp body (env `pushE` (lhs, x)) aenv Evar (Var _ ix) -> return $ prj ix env Const tp c -> return $ ir tp $ scalar tp c - PrimConst c -> let tp = (SingleScalarType $ primConstType c) + PrimConst c -> let tp = primConstType c in return $ ir tp $ scalar tp $ primConst c PrimApp f x -> primFun f x Undef tp -> return $ ir tp $ undef tp @@ -165,7 +166,7 @@ llvmOfOpenExp top env aenv = cvtE top go (VecRnil _) _ = internalError "index mismatch" go (VecRsucc vecr') i = do xs <- go vecr' (i - 1) - x <- instr' $ ExtractElement (fromIntegral i - 1) vec + x <- instr' $ ExtractElement TypeInt vec (constOp (i - 1)) return $ OP_Pair xs (ir singleTp x) singleTp :: SingleType single -- GHC 8.4 cannot infer this type for some reason @@ -307,6 +308,7 @@ llvmOfOpenExp top env aenv = cvtE top PrimEq t -> primbool $ A.uncurry (A.eq t) =<< cvtE x PrimNEq t -> primbool $ A.uncurry (A.neq t) =<< cvtE x PrimLNot -> primbool $ A.lnot =<< bool (cvtE x) + PrimVectorIndex v i -> A.uncurry (A.vecIndex v i) =<< cvtE x -- no missing patterns, whoo! diff --git a/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs b/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs index 4566ab1d2..b54df8486 100644 --- a/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs +++ b/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs @@ -182,8 +182,9 @@ data Instruction a where -- -- - ExtractElement :: Int32 -- TupleIdx (ProdRepr (Vec n a)) a + ExtractElement :: IntegralType i -- TupleIdx (ProdRepr (Vec n a)) a -> Operand (Vec n a) + -> Operand i -> Instruction a -- @@ -406,7 +407,7 @@ instance Downcast (Instruction a) LLVM.Instruction where BXor _ x y -> LLVM.Xor (downcast x) (downcast y) md LNot x -> LLVM.Xor (downcast x) (LLVM.ConstantOperand (LLVM.Int 1 1)) md InsertElement i v x -> LLVM.InsertElement (downcast v) (downcast x) (constant i) md - ExtractElement i v -> LLVM.ExtractElement (downcast v) (constant i) md + ExtractElement _ v i -> LLVM.ExtractElement (downcast v) (downcast i) md ExtractValue _ i s -> extractStruct i (downcast s) Load _ v p -> LLVM.Load (downcast v) (downcast p) atomicity alignment md Store v p x -> LLVM.Store (downcast v) (downcast p) (downcast x) atomicity alignment md @@ -594,7 +595,7 @@ instance TypeOf Instruction where LAnd x _ -> typeOf x LOr x _ -> typeOf x LNot x -> typeOf x - ExtractElement _ x -> typeOfVec x + ExtractElement _ x _ -> typeOfVec x InsertElement _ x _ -> typeOf x ExtractValue t _ _ -> scalar t Load t _ _ -> scalar t diff --git a/accelerate-llvm/src/LLVM/AST/Type/Operand.hs b/accelerate-llvm/src/LLVM/AST/Type/Operand.hs index 9bd3ab21e..f2e73b8c0 100644 --- a/accelerate-llvm/src/LLVM/AST/Type/Operand.hs +++ b/accelerate-llvm/src/LLVM/AST/Type/Operand.hs @@ -15,6 +15,7 @@ module LLVM.AST.Type.Operand ( Operand(..), + constOp, ) where @@ -32,6 +33,9 @@ data Operand a where LocalReference :: Type a -> Name a -> Operand a ConstantOperand :: Constant a -> Operand a +constOp :: (IsScalar a) => a -> Operand a +constOp x = ConstantOperand (ScalarConstant scalarType x) + -- | Convert to llvm-hs -- From 29726bd137525def46f3d88680064b43d70965f1 Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 2 Nov 2021 12:25:16 +0100 Subject: [PATCH 2/3] implement vector operations --- .../Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs | 2 +- .../Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs | 11 ++++++++--- .../src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs | 9 ++++----- .../src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs | 3 ++- accelerate-llvm/src/LLVM/AST/Type/Instruction.hs | 7 ++++--- 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs index 5955334b8..3b5e9e491 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs @@ -448,7 +448,7 @@ shfl sop tR val delta = go tR val d <- instr $ ExtractElement integralType c (constOp (i-1)) e <- integral integralType d f <- repack (i-1) - g <- instr $ InsertElement (i-1) (op v' f) (op integralType e) + g <- instr $ InsertElement integralType (op v' f) (constOp (i-1)) (op integralType e) return g h <- repack (P.fromIntegral m) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs index 47ab72c60..80c925eca 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs @@ -473,9 +473,11 @@ vecCreate :: VectorType (Vec n a) -> CodeGen arch (Operands (Vec n a)) vecCreate = undefined vecIndex :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> CodeGen arch (Operands a) -vecIndex tv ti (OP_Vec v) i = do - (OP_Int32 i') <- fromIntegral ti (IntegralNumType TypeInt32) i - instr $ ExtractElement TypeInt32 v i' +vecIndex tv ti (op tv -> v) (op ti -> i) = instr $ ExtractElement ti v i + +vecWrite :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> Operands a -> CodeGen arch (Operands (Vec n a)) +vecWrite tv@(VectorType _ ta) ti (op tv -> v) (op ti -> i) (op ta -> val) = instr $ InsertElement ti v i val + -- Logical operators @@ -570,6 +572,9 @@ unpair (OP_Pair x y) = (x, y) uncurry :: (Operands a -> Operands b -> c) -> Operands (a, b) -> c uncurry f (OP_Pair x y) = f x y +uncurry3 :: (Operands a -> Operands b -> Operands c -> d) -> Operands (a, (b, c)) -> d +uncurry3 f (OP_Pair x (OP_Pair y z)) = f x y z + unbool :: Operands Bool -> Operand Bool unbool (OP_Bool x) = x diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs index 049960857..a4fb10f27 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs @@ -174,18 +174,17 @@ load addrspace e v p | SingleScalarType{} <- e = instr' $ Load e v p | VectorScalarType s <- e , VectorType n base <- s - , m <- fromIntegral n - = if popCount m == 1 + = if popCount n == 1 then instr' $ Load e v p else do p' <- instr' $ PtrCast (PtrPrimType (ScalarPrimType (SingleScalarType base)) addrspace) p -- let go i w - | i >= m = return w + | i >= n = return w | otherwise = do q <- instr' $ GetElementPtr p' [integral integralType i] r <- instr' $ Load (SingleScalarType base) v q - w' <- instr' $ InsertElement i w r + w' <- instr' $ InsertElement integralType w (constOp i) r go (i+1) w' -- go 0 (undef e) @@ -214,7 +213,7 @@ store addrspace volatility e p v let go i | i >= n = return () | otherwise = do - x <- instr' $ ExtractElement integralType v (constOp n) + x <- instr' $ ExtractElement integralType v (constOp i) q <- instr' $ GetElementPtr p' [integral integralType i] _ <- instr' $ Store volatility q x go (i+1) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs index 95149b14b..1dad5ffb1 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs @@ -153,7 +153,7 @@ llvmOfOpenExp top env aenv = cvtE top go (VecRnil _) _ OP_Unit = internalError "index mismatch" go (VecRsucc vecr') i (OP_Pair xs x) = do vec <- go vecr' (i - 1) xs - instr' $ InsertElement (fromIntegral i - 1) vec (op singleTp x) + instr' $ InsertElement integralType vec (constOp (i - 1)) (op singleTp x) singleTp :: SingleType single -- GHC 8.4 cannot infer this type for some reason tp@(VectorType n singleTp) = vecRvector vecr @@ -309,6 +309,7 @@ llvmOfOpenExp top env aenv = cvtE top PrimNEq t -> primbool $ A.uncurry (A.neq t) =<< cvtE x PrimLNot -> primbool $ A.lnot =<< bool (cvtE x) PrimVectorIndex v i -> A.uncurry (A.vecIndex v i) =<< cvtE x + PrimVectorWrite v i -> A.uncurry3 (A.vecWrite v i) =<< cvtE x -- no missing patterns, whoo! diff --git a/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs b/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs index b54df8486..eedea814e 100644 --- a/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs +++ b/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs @@ -189,8 +189,9 @@ data Instruction a where -- -- - InsertElement :: Int32 -- TupleIdx (ProdRepr (Vec n a)) a + InsertElement :: IntegralType i -- TupleIdx (ProdRepr (Vec n a)) a -> Operand (Vec n a) + -> Operand i -> Operand a -> Instruction (Vec n a) @@ -406,7 +407,7 @@ instance Downcast (Instruction a) LLVM.Instruction where LOr x y -> LLVM.Or (downcast x) (downcast y) md BXor _ x y -> LLVM.Xor (downcast x) (downcast y) md LNot x -> LLVM.Xor (downcast x) (LLVM.ConstantOperand (LLVM.Int 1 1)) md - InsertElement i v x -> LLVM.InsertElement (downcast v) (downcast x) (constant i) md + InsertElement _ v i x -> LLVM.InsertElement (downcast v) (downcast x) (downcast i) md ExtractElement _ v i -> LLVM.ExtractElement (downcast v) (downcast i) md ExtractValue _ i s -> extractStruct i (downcast s) Load _ v p -> LLVM.Load (downcast v) (downcast p) atomicity alignment md @@ -596,7 +597,7 @@ instance TypeOf Instruction where LOr x _ -> typeOf x LNot x -> typeOf x ExtractElement _ x _ -> typeOfVec x - InsertElement _ x _ -> typeOf x + InsertElement _ x _ _ -> typeOf x ExtractValue t _ _ -> scalar t Load t _ _ -> scalar t Store{} -> VoidType From ef0172f8891a8f0e7ae9ce76028848c067dd6d44 Mon Sep 17 00:00:00 2001 From: Hugo Date: Thu, 2 Dec 2021 16:04:45 +0100 Subject: [PATCH 3/3] Move vec operations to correct AST --- .../Accelerate/LLVM/CodeGen/Arithmetic.hs | 5 --- .../Data/Array/Accelerate/LLVM/CodeGen/Exp.hs | 15 ++++++-- .../src/Data/Array/Accelerate/LLVM/Compile.hs | 36 ++++++++++--------- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs index 80c925eca..218992b1c 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs @@ -472,11 +472,6 @@ min ty x y vecCreate :: VectorType (Vec n a) -> CodeGen arch (Operands (Vec n a)) vecCreate = undefined -vecIndex :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> CodeGen arch (Operands a) -vecIndex tv ti (op tv -> v) (op ti -> i) = instr $ ExtractElement ti v i - -vecWrite :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> Operands a -> CodeGen arch (Operands (Vec n a)) -vecWrite tv@(VectorType _ ta) ti (op tv -> v) (op ti -> i) (op ta -> val) = instr $ InsertElement ti v i val diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs index 1dad5ffb1..dd8efde3a 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs @@ -114,6 +114,13 @@ llvmOfOpenExp top env aenv = cvtE top Pair e1 e2 -> join $ pair <$> cvtE e1 <*> cvtE e2 VecPack vecr e -> vecPack vecr =<< cvtE e VecUnpack vecr e -> vecUnpack vecr =<< cvtE e + VecIndex vt ti v i -> do v' <- cvtE v + i' <- cvtE i + vecIndexGen vt ti v' i' + VecWrite vt ti v i e -> do v' <- cvtE v + i' <- cvtE i + e' <- cvtE e + vecWriteGen vt ti v' i' e' Foreign tp asm f x -> foreignE tp asm f =<< cvtE x Case tag xs mx -> A.caseof (expType (snd (head xs))) (cvtE tag) [(t,cvtE e) | (t,e) <- xs] (fmap cvtE mx) Cond c t e -> cond (expType t) (cvtE c) (cvtE t) (cvtE e) @@ -172,6 +179,12 @@ llvmOfOpenExp top env aenv = cvtE top singleTp :: SingleType single -- GHC 8.4 cannot infer this type for some reason VectorType n singleTp = vecRvector vecr + vecIndexGen :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> CodeGen arch (Operands a) + vecIndexGen tv ti (op tv -> v) (op ti -> i) = instr $ ExtractElement ti v i + + vecWriteGen :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> Operands a -> CodeGen arch (Operands (Vec n a)) + vecWriteGen tv@(VectorType _ ts) ti (op tv -> v) (op ti -> i) (op ts -> e) = instr $ InsertElement ti v i e + linearIndex :: ArrayVar aenv (Array sh e) -> Operands Int -> IROpenExp arch env aenv e linearIndex (Var repr v) = linearIndexArray (irArray repr (aprj v aenv)) @@ -308,8 +321,6 @@ llvmOfOpenExp top env aenv = cvtE top PrimEq t -> primbool $ A.uncurry (A.eq t) =<< cvtE x PrimNEq t -> primbool $ A.uncurry (A.neq t) =<< cvtE x PrimLNot -> primbool $ A.lnot =<< bool (cvtE x) - PrimVectorIndex v i -> A.uncurry (A.vecIndex v i) =<< cvtE x - PrimVectorWrite v i -> A.uncurry3 (A.vecWrite v i) =<< cvtE x -- no missing patterns, whoo! diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Compile.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Compile.hs index fee50f538..e84404734 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Compile.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Compile.hs @@ -368,24 +368,26 @@ compileOpenAcc = traverseAcc Undef tp -> return $ pure $ Undef tp Foreign tp ff f x -> foreignE tp ff f x -- - Let lhs a b -> liftA2 (Let lhs) <$> travE a <*> travE b - IndexSlice slix x s -> liftA2 (IndexSlice slix) <$> travE x <*> travE s - IndexFull slix x s -> liftA2 (IndexFull slix) <$> travE x <*> travE s - ToIndex shr s i -> liftA2 (ToIndex shr) <$> travE s <*> travE i - FromIndex shr s i -> liftA2 (FromIndex shr) <$> travE s <*> travE i + Let lhs a b -> liftA2 (Let lhs) <$> travE a <*> travE b + IndexSlice slix x s -> liftA2 (IndexSlice slix) <$> travE x <*> travE s + IndexFull slix x s -> liftA2 (IndexFull slix) <$> travE x <*> travE s + ToIndex shr s i -> liftA2 (ToIndex shr) <$> travE s <*> travE i + FromIndex shr s i -> liftA2 (FromIndex shr) <$> travE s <*> travE i Nil -> return $ pure Nil - Pair e1 e2 -> liftA2 Pair <$> travE e1 <*> travE e2 - VecPack vecr e -> liftA (VecPack vecr) <$> travE e - VecUnpack vecr e -> liftA (VecUnpack vecr) <$> travE e - Case t xs x -> liftA3 Case <$> travE t <*> travLE xs <*> travME x - Cond p t e -> liftA3 Cond <$> travE p <*> travE t <*> travE e - While p f x -> liftA3 While <$> travF p <*> travF f <*> travE x - PrimApp f e -> liftA (PrimApp f) <$> travE e - Index a e -> liftA2 Index <$> travA a <*> travE e - LinearIndex a e -> liftA2 LinearIndex <$> travA a <*> travE e - Shape a -> liftA Shape <$> travA a - ShapeSize shr e -> liftA (ShapeSize shr) <$> travE e - Coerce t1 t2 x -> liftA (Coerce t1 t2) <$> travE x + Pair e1 e2 -> liftA2 Pair <$> travE e1 <*> travE e2 + VecPack vecr e -> liftA (VecPack vecr) <$> travE e + VecUnpack vecr e -> liftA (VecUnpack vecr) <$> travE e + VecIndex vt it v i -> liftA2 (VecIndex vt it) <$> travE v <*> travE i + VecWrite vt it v i e -> liftA3 (VecWrite vt it) <$> travE v <*> travE i <*> travE e + Case t xs x -> liftA3 Case <$> travE t <*> travLE xs <*> travME x + Cond p t e -> liftA3 Cond <$> travE p <*> travE t <*> travE e + While p f x -> liftA3 While <$> travF p <*> travF f <*> travE x + PrimApp f e -> liftA (PrimApp f) <$> travE e + Index a e -> liftA2 Index <$> travA a <*> travE e + LinearIndex a e -> liftA2 LinearIndex <$> travA a <*> travE e + Shape a -> liftA Shape <$> travA a + ShapeSize shr e -> liftA (ShapeSize shr) <$> travE e + Coerce t1 t2 x -> liftA (Coerce t1 t2) <$> travE x where travA :: ArrayVar aenv (Array sh e)