Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the vector indexing operation #75

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,10 @@ shfl sop tR val delta = go tR val
repack :: Int32 -> CodeGen PTX (Operands (Vec m Int32))
repack 0 = return $ ir v' (A.undef (VectorScalarType v'))
repack i = do
d <- instr $ ExtractElement (i-1) c
d <- instr $ ExtractElement integralType c (constOp (i-1))
e <- integral integralType d
f <- repack (i-1)
g <- instr $ InsertElement (i-1) (op v' f) (op integralType e)
g <- instr $ InsertElement integralType (op v' f) (constOp (i-1)) (op integralType e)
return g

h <- repack (P.fromIntegral m)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
module Data.Array.Accelerate.LLVM.CodeGen.Arithmetic
where

import Data.Primitive.Vec

import Data.Array.Accelerate.AST ( PrimMaybe )
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Representation.Tag
Expand Down Expand Up @@ -464,6 +466,14 @@ min ty x y
| otherwise = do c <- unbool <$> lte ty x y
binop (flip Select c) ty x y

-- Vector operators
-- ----------------------

vecCreate :: VectorType (Vec n a) -> CodeGen arch (Operands (Vec n a))
vecCreate = undefined




-- Logical operators
-- -----------------
Expand Down Expand Up @@ -557,6 +567,9 @@ unpair (OP_Pair x y) = (x, y)
uncurry :: (Operands a -> Operands b -> c) -> Operands (a, b) -> c
uncurry f (OP_Pair x y) = f x y

uncurry3 :: (Operands a -> Operands b -> Operands c -> d) -> Operands (a, (b, c)) -> d
uncurry3 f (OP_Pair x (OP_Pair y z)) = f x y z

unbool :: Operands Bool -> Operand Bool
unbool (OP_Bool x) = x

Expand Down
17 changes: 8 additions & 9 deletions accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Representation

import Data.Array.Accelerate.Representation.Array
Expand Down Expand Up @@ -173,18 +174,17 @@ load addrspace e v p
| SingleScalarType{} <- e = instr' $ Load e v p
| VectorScalarType s <- e
, VectorType n base <- s
, m <- fromIntegral n
= if popCount m == 1
= if popCount n == 1
then instr' $ Load e v p
else do
p' <- instr' $ PtrCast (PtrPrimType (ScalarPrimType (SingleScalarType base)) addrspace) p
--
let go i w
| i >= m = return w
| i >= n = return w
| otherwise = do
q <- instr' $ GetElementPtr p' [integral integralType i]
r <- instr' $ Load (SingleScalarType base) v q
w' <- instr' $ InsertElement i w r
w' <- instr' $ InsertElement integralType w (constOp i) r
go (i+1) w'
--
go 0 (undef e)
Expand All @@ -205,16 +205,15 @@ store addrspace volatility e p v
| SingleScalarType{} <- e = do_ $ Store volatility p v
| VectorScalarType s <- e
, VectorType n base <- s
, m <- fromIntegral n
= if popCount m == 1
= if popCount n == 1
then do_ $ Store volatility p v
else do
p' <- instr' $ PtrCast (PtrPrimType (ScalarPrimType (SingleScalarType base)) addrspace) p
--
let go i
| i >= m = return ()
let go i
| i >= n = return ()
| otherwise = do
x <- instr' $ ExtractElement i v
x <- instr' $ ExtractElement integralType v (constOp i)
q <- instr' $ GetElementPtr p' [integral integralType i]
_ <- instr' $ Store volatility q x
go (i+1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ scalar t = ConstantOperand . ScalarConstant t
single :: SingleType a -> a -> Operand a
single t = scalar (SingleScalarType t)

vector :: VectorType (Vec n a) -> (Vec n a) -> Operand (Vec n a)
vector :: VectorType (Vec n a) -> Vec n a -> Operand (Vec n a)
vector t = scalar (VectorScalarType t)

num :: NumType a -> a -> Operand a
Expand Down
22 changes: 18 additions & 4 deletions accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop as L
import Data.Primitive.Vec

import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Operand ( Operand )
import LLVM.AST.Type.Operand ( Operand(..), constOp)
import LLVM.AST.Type.Constant ( Constant(..), )

import Control.Applicative hiding ( Const )
import Control.Monad
Expand Down Expand Up @@ -105,14 +106,21 @@ llvmOfOpenExp top env aenv = cvtE top
llvmOfOpenExp body (env `pushE` (lhs, x)) aenv
Evar (Var _ ix) -> return $ prj ix env
Const tp c -> return $ ir tp $ scalar tp c
PrimConst c -> let tp = (SingleScalarType $ primConstType c)
PrimConst c -> let tp = primConstType c
in return $ ir tp $ scalar tp $ primConst c
PrimApp f x -> primFun f x
Undef tp -> return $ ir tp $ undef tp
Nil -> return $ OP_Unit
Pair e1 e2 -> join $ pair <$> cvtE e1 <*> cvtE e2
VecPack vecr e -> vecPack vecr =<< cvtE e
VecUnpack vecr e -> vecUnpack vecr =<< cvtE e
VecIndex vt ti v i -> do v' <- cvtE v
i' <- cvtE i
vecIndexGen vt ti v' i'
VecWrite vt ti v i e -> do v' <- cvtE v
i' <- cvtE i
e' <- cvtE e
vecWriteGen vt ti v' i' e'
Foreign tp asm f x -> foreignE tp asm f =<< cvtE x
Case tag xs mx -> A.caseof (expType (snd (head xs))) (cvtE tag) [(t,cvtE e) | (t,e) <- xs] (fmap cvtE mx)
Cond c t e -> cond (expType t) (cvtE c) (cvtE t) (cvtE e)
Expand Down Expand Up @@ -152,7 +160,7 @@ llvmOfOpenExp top env aenv = cvtE top
go (VecRnil _) _ OP_Unit = internalError "index mismatch"
go (VecRsucc vecr') i (OP_Pair xs x) = do
vec <- go vecr' (i - 1) xs
instr' $ InsertElement (fromIntegral i - 1) vec (op singleTp x)
instr' $ InsertElement integralType vec (constOp (i - 1)) (op singleTp x)

singleTp :: SingleType single -- GHC 8.4 cannot infer this type for some reason
tp@(VectorType n singleTp) = vecRvector vecr
Expand All @@ -165,12 +173,18 @@ llvmOfOpenExp top env aenv = cvtE top
go (VecRnil _) _ = internalError "index mismatch"
go (VecRsucc vecr') i = do
xs <- go vecr' (i - 1)
x <- instr' $ ExtractElement (fromIntegral i - 1) vec
x <- instr' $ ExtractElement TypeInt vec (constOp (i - 1))
return $ OP_Pair xs (ir singleTp x)

singleTp :: SingleType single -- GHC 8.4 cannot infer this type for some reason
VectorType n singleTp = vecRvector vecr

vecIndexGen :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> CodeGen arch (Operands a)
vecIndexGen tv ti (op tv -> v) (op ti -> i) = instr $ ExtractElement ti v i

vecWriteGen :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> Operands a -> CodeGen arch (Operands (Vec n a))
vecWriteGen tv@(VectorType _ ts) ti (op tv -> v) (op ti -> i) (op ts -> e) = instr $ InsertElement ti v i e

linearIndex :: ArrayVar aenv (Array sh e) -> Operands Int -> IROpenExp arch env aenv e
linearIndex (Var repr v) = linearIndexArray (irArray repr (aprj v aenv))

Expand Down
36 changes: 19 additions & 17 deletions accelerate-llvm/src/Data/Array/Accelerate/LLVM/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -368,24 +368,26 @@ compileOpenAcc = traverseAcc
Undef tp -> return $ pure $ Undef tp
Foreign tp ff f x -> foreignE tp ff f x
--
Let lhs a b -> liftA2 (Let lhs) <$> travE a <*> travE b
IndexSlice slix x s -> liftA2 (IndexSlice slix) <$> travE x <*> travE s
IndexFull slix x s -> liftA2 (IndexFull slix) <$> travE x <*> travE s
ToIndex shr s i -> liftA2 (ToIndex shr) <$> travE s <*> travE i
FromIndex shr s i -> liftA2 (FromIndex shr) <$> travE s <*> travE i
Let lhs a b -> liftA2 (Let lhs) <$> travE a <*> travE b
IndexSlice slix x s -> liftA2 (IndexSlice slix) <$> travE x <*> travE s
IndexFull slix x s -> liftA2 (IndexFull slix) <$> travE x <*> travE s
ToIndex shr s i -> liftA2 (ToIndex shr) <$> travE s <*> travE i
FromIndex shr s i -> liftA2 (FromIndex shr) <$> travE s <*> travE i
Nil -> return $ pure Nil
Pair e1 e2 -> liftA2 Pair <$> travE e1 <*> travE e2
VecPack vecr e -> liftA (VecPack vecr) <$> travE e
VecUnpack vecr e -> liftA (VecUnpack vecr) <$> travE e
Case t xs x -> liftA3 Case <$> travE t <*> travLE xs <*> travME x
Cond p t e -> liftA3 Cond <$> travE p <*> travE t <*> travE e
While p f x -> liftA3 While <$> travF p <*> travF f <*> travE x
PrimApp f e -> liftA (PrimApp f) <$> travE e
Index a e -> liftA2 Index <$> travA a <*> travE e
LinearIndex a e -> liftA2 LinearIndex <$> travA a <*> travE e
Shape a -> liftA Shape <$> travA a
ShapeSize shr e -> liftA (ShapeSize shr) <$> travE e
Coerce t1 t2 x -> liftA (Coerce t1 t2) <$> travE x
Pair e1 e2 -> liftA2 Pair <$> travE e1 <*> travE e2
VecPack vecr e -> liftA (VecPack vecr) <$> travE e
VecUnpack vecr e -> liftA (VecUnpack vecr) <$> travE e
VecIndex vt it v i -> liftA2 (VecIndex vt it) <$> travE v <*> travE i
VecWrite vt it v i e -> liftA3 (VecWrite vt it) <$> travE v <*> travE i <*> travE e
Case t xs x -> liftA3 Case <$> travE t <*> travLE xs <*> travME x
Cond p t e -> liftA3 Cond <$> travE p <*> travE t <*> travE e
While p f x -> liftA3 While <$> travF p <*> travF f <*> travE x
PrimApp f e -> liftA (PrimApp f) <$> travE e
Index a e -> liftA2 Index <$> travA a <*> travE e
LinearIndex a e -> liftA2 LinearIndex <$> travA a <*> travE e
Shape a -> liftA Shape <$> travA a
ShapeSize shr e -> liftA (ShapeSize shr) <$> travE e
Coerce t1 t2 x -> liftA (Coerce t1 t2) <$> travE x

where
travA :: ArrayVar aenv (Array sh e)
Expand Down
14 changes: 8 additions & 6 deletions accelerate-llvm/src/LLVM/AST/Type/Instruction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,16 @@ data Instruction a where

-- <http://llvm.org/docs/LangRef.html#extractelement-instruction>
--
ExtractElement :: Int32 -- TupleIdx (ProdRepr (Vec n a)) a
ExtractElement :: IntegralType i -- TupleIdx (ProdRepr (Vec n a)) a
-> Operand (Vec n a)
-> Operand i
-> Instruction a

-- <http://llvm.org/docs/LangRef.html#insertelement-instruction>
--
InsertElement :: Int32 -- TupleIdx (ProdRepr (Vec n a)) a
InsertElement :: IntegralType i -- TupleIdx (ProdRepr (Vec n a)) a
-> Operand (Vec n a)
-> Operand i
-> Operand a
-> Instruction (Vec n a)

Expand Down Expand Up @@ -405,8 +407,8 @@ instance Downcast (Instruction a) LLVM.Instruction where
LOr x y -> LLVM.Or (downcast x) (downcast y) md
BXor _ x y -> LLVM.Xor (downcast x) (downcast y) md
LNot x -> LLVM.Xor (downcast x) (LLVM.ConstantOperand (LLVM.Int 1 1)) md
InsertElement i v x -> LLVM.InsertElement (downcast v) (downcast x) (constant i) md
ExtractElement i v -> LLVM.ExtractElement (downcast v) (constant i) md
InsertElement _ v i x -> LLVM.InsertElement (downcast v) (downcast x) (downcast i) md
ExtractElement _ v i -> LLVM.ExtractElement (downcast v) (downcast i) md
ExtractValue _ i s -> extractStruct i (downcast s)
Load _ v p -> LLVM.Load (downcast v) (downcast p) atomicity alignment md
Store v p x -> LLVM.Store (downcast v) (downcast p) (downcast x) atomicity alignment md
Expand Down Expand Up @@ -594,8 +596,8 @@ instance TypeOf Instruction where
LAnd x _ -> typeOf x
LOr x _ -> typeOf x
LNot x -> typeOf x
ExtractElement _ x -> typeOfVec x
InsertElement _ x _ -> typeOf x
ExtractElement _ x _ -> typeOfVec x
InsertElement _ x _ _ -> typeOf x
ExtractValue t _ _ -> scalar t
Load t _ _ -> scalar t
Store{} -> VoidType
Expand Down
4 changes: 4 additions & 0 deletions accelerate-llvm/src/LLVM/AST/Type/Operand.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
module LLVM.AST.Type.Operand (

Operand(..),
constOp,

) where

Expand All @@ -32,6 +33,9 @@ data Operand a where
LocalReference :: Type a -> Name a -> Operand a
ConstantOperand :: Constant a -> Operand a

constOp :: (IsScalar a) => a -> Operand a
constOp x = ConstantOperand (ScalarConstant scalarType x)


-- | Convert to llvm-hs
--
Expand Down