From e6cd2ebb24144b31c3377facc896bb4b92e77289 Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 19 Oct 2021 23:56:05 +0200 Subject: [PATCH 01/12] Vector indexing operations and empty vector constructor --- .gitignore | 2 ++ src/Data/Array/Accelerate.hs | 3 ++ src/Data/Array/Accelerate/AST.hs | 12 ++++++++ src/Data/Array/Accelerate/Classes/Enum.hs | 3 +- src/Data/Array/Accelerate/Classes/Vector.hs | 31 +++++++++++++++++++++ src/Data/Array/Accelerate/Smart.hs | 12 ++++++++ src/Data/Primitive/Vec.hs | 14 ++++++++++ 7 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 src/Data/Array/Accelerate/Classes/Vector.hs diff --git a/.gitignore b/.gitignore index 2dc9bad21..eec9590ea 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ /docs/_build *.hi *.o + +hie.yaml diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index ff1729f27..5654cd9f9 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -310,6 +310,7 @@ module Data.Array.Accelerate ( -- ** SIMD vectors Vec, VecElt, + mkVec, -- ** Type classes -- *** Basic type classes @@ -317,6 +318,7 @@ module Data.Array.Accelerate ( Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, Enum, succ, pred, Bounded, minBound, maxBound, + Vectoring(..), -- Functor(..), (<$>), ($>), void, -- Monad(..), @@ -445,6 +447,7 @@ import Data.Array.Accelerate.Classes.Rational import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Classes.RealFrac import Data.Array.Accelerate.Classes.ToFloating +import Data.Array.Accelerate.Classes.Vector import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index c84f5723f..0a887802f 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -748,6 +748,9 @@ data PrimFun sig where PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool) PrimLNot :: PrimFun (PrimBool -> PrimBool) + -- local array operators + PrimVectorIndex :: KnownNat n => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) + -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b) @@ -924,6 +927,12 @@ primFunType = \case PrimLOr -> binary' tbool PrimLNot -> unary' tbool +-- Local Vector operations + PrimVectorIndex v'@(VectorType _ a) i' -> + let v = singleVector v' + i = integral i' + in (v `TupRpair` i, single a) + -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) @@ -936,6 +945,7 @@ primFunType = \case compare' a = binary (single a) tbool single = TupRsingle . SingleScalarType + singleVector = TupRsingle . VectorScalarType num = TupRsingle . SingleScalarType . NumSingleType integral = num . IntegralNumType floating = num . FloatingNumType @@ -1165,6 +1175,7 @@ rnfPrimFun (PrimMin t) = rnfSingleType t rnfPrimFun PrimLAnd = () rnfPrimFun PrimLOr = () rnfPrimFun PrimLNot = () +rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f @@ -1391,6 +1402,7 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] liftPrimFun PrimLAnd = [|| PrimLAnd ||] liftPrimFun PrimLOr = [|| PrimLOr ||] liftPrimFun PrimLNot = [|| PrimLNot ||] +liftPrimFun (PrimVectorIndex v i) = [||PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] diff --git a/src/Data/Array/Accelerate/Classes/Enum.hs b/src/Data/Array/Accelerate/Classes/Enum.hs index 84b344273..10e946ee5 100644 --- a/src/Data/Array/Accelerate/Classes/Enum.hs +++ b/src/Data/Array/Accelerate/Classes/Enum.hs @@ -187,8 +187,7 @@ defaultFromEnum = preludeError "fromEnum" preludeError :: String -> a preludeError x = error - $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x - , "" + $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , "" , "These Prelude.Enum instances are present only to fulfil superclass" , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs new file mode 100644 index 000000000..69f62e7eb --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} +-- | +-- Module : Data.Array.Accelerate.Classes.Vector +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- +module Data.Array.Accelerate.Classes.Vector where + +import GHC.TypeLits +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Smart +import Data.Primitive.Vec + +class Vectoring a b c | a -> b where + indexAt :: a -> c -> b + +instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) (Exp Int) where + indexAt = mkVectorIndex + + diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 8fa577f41..14c043d1f 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -12,6 +12,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE PolyKinds #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Smart @@ -71,6 +72,9 @@ module Data.Array.Accelerate.Smart ( -- ** Smart constructors for type coercion functions mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), + -- ** Smart constructors for vector operations + mkVectorIndex, + -- ** Auxiliary functions ($$), ($$$), ($$$$), ($$$$$), ApplyAcc(..), @@ -83,6 +87,7 @@ module Data.Array.Accelerate.Smart ( ) where +import Data.Proxy import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Array @@ -95,6 +100,7 @@ import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Array ( Arrays ) import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) ) import Data.Array.Accelerate.Type @@ -1172,6 +1178,12 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil where x = SmartExp $ Prj PairIdxLeft a +-- Operators from Vec +mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a +mkVectorIndex = let n :: Int + n = fromIntegral $ natVal $ Proxy @n + in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType + -- Numeric conversions mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 34a77635b..93b0395c0 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -31,12 +32,16 @@ module Data.Primitive.Vec ( Vec8, pattern Vec8, Vec16, pattern Vec16, + mkVec, + listOfVec, liftVec, ) where +import Data.Proxy import Control.Monad.ST +import Control.Monad.Reader import Data.Primitive.ByteArray import Data.Primitive.Types import Data.Text.Prettyprint.Doc @@ -83,6 +88,14 @@ import GHC.Word -- data Vec (n :: Nat) a = Vec ByteArray# +mkVec :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a +mkVec vs = runST $ do + let n :: Int = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + zipWithM_ (writeByteArray mba) [0..n] vs + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where @@ -259,6 +272,7 @@ packVec16 a b c d e f g h i j k l m n o p = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# + -- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able -- to do this without copying, but I don't think the definition of ByteArray# is -- exported (or it is deeply magical). From 0c80d44a4c150d479d64af4c4b6d727eb6aa9d72 Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 20 Oct 2021 16:46:35 +0200 Subject: [PATCH 02/12] created empty vector lifted in Exp --- src/Data/Array/Accelerate.hs | 4 +-- src/Data/Array/Accelerate/AST.hs | 35 +++++++++++++-------- src/Data/Array/Accelerate/Classes/Vector.hs | 27 ++++++++++------ src/Data/Array/Accelerate/Smart.hs | 8 ++++- src/Data/Primitive/Vec.hs | 2 -- 5 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 5654cd9f9..8811695b8 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -310,7 +310,7 @@ module Data.Array.Accelerate ( -- ** SIMD vectors Vec, VecElt, - mkVec, + Vectoring(..), -- ** Type classes -- *** Basic type classes @@ -318,7 +318,7 @@ module Data.Array.Accelerate ( Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, Enum, succ, pred, Bounded, minBound, maxBound, - Vectoring(..), + -- Functor(..), (<$>), ($>), void, -- Monad(..), diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 0a887802f..a07920466 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -655,6 +655,9 @@ data PrimConst ty where -- constant from Floating PrimPi :: FloatingType a -> PrimConst a + -- constant for empty Vec + PrimVectorCreate :: KnownNat n => VectorType (Vec n a) -> PrimConst (Vec n a) + -- |Primitive scalar operations -- @@ -828,7 +831,7 @@ expType = \case While _ (Lam lhs _) _ -> lhsToTupR lhs While{} -> error "What's the matter, you're running in the shadows" Const tR _ -> TupRsingle tR - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f Index (Var repr _) _ -> arrayRtype repr LinearIndex (Var repr _) _ -> arrayRtype repr @@ -837,17 +840,21 @@ expType = \case Undef tR -> TupRsingle tR Coerce _ tR _ -> TupRsingle tR -primConstType :: PrimConst a -> SingleType a +primConstType :: PrimConst a -> ScalarType a primConstType = \case PrimMinBound t -> bounded t PrimMaxBound t -> bounded t PrimPi t -> floating t + PrimVectorCreate t -> vector t where - bounded :: BoundedType a -> SingleType a - bounded (IntegralBoundedType t) = NumSingleType $ IntegralNumType t + bounded :: BoundedType a -> ScalarType a + bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t + + floating :: FloatingType t -> ScalarType t + floating = SingleScalarType . NumSingleType . FloatingNumType - floating :: FloatingType t -> SingleType t - floating = NumSingleType . FloatingNumType + vector :: forall n a. (KnownNat n) => VectorType (Vec n a) -> ScalarType (Vec n a) + vector = VectorScalarType primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) primFunType = \case @@ -1110,9 +1117,10 @@ rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf = rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b rnfPrimConst :: PrimConst c -> () -rnfPrimConst (PrimMinBound t) = rnfBoundedType t -rnfPrimConst (PrimMaxBound t) = rnfBoundedType t -rnfPrimConst (PrimPi t) = rnfFloatingType t +rnfPrimConst (PrimMinBound t) = rnfBoundedType t +rnfPrimConst (PrimMaxBound t) = rnfBoundedType t +rnfPrimConst (PrimPi t) = rnfFloatingType t +rnfPrimConst (PrimVectorCreate t) = rnfVectorType t rnfPrimFun :: PrimFun f -> () rnfPrimFun (PrimAdd t) = rnfNumType t @@ -1337,9 +1345,10 @@ liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||] liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] liftPrimConst :: PrimConst c -> CodeQ (PrimConst c) -liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] -liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] -liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] +liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] +liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] +liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] +liftPrimConst (PrimVectorCreate t) = [|| PrimVectorCreate $$(liftVectorType t) ||] liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] @@ -1402,7 +1411,7 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] liftPrimFun PrimLAnd = [|| PrimLAnd ||] liftPrimFun PrimLOr = [|| PrimLOr ||] liftPrimFun PrimLNot = [|| PrimLNot ||] -liftPrimFun (PrimVectorIndex v i) = [||PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] +liftPrimFun (PrimVectorIndex v i) = [|| PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 69f62e7eb..32a618761 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -1,8 +1,10 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MonoLocalBinds #-} -{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE GADTs #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | @@ -16,16 +18,21 @@ -- module Data.Array.Accelerate.Classes.Vector where +import Data.Kind import GHC.TypeLits -import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Smart import Data.Primitive.Vec -class Vectoring a b c | a -> b where - indexAt :: a -> c -> b +class Vectoring vector a | vector -> a where + type IndexType vector :: Type + vecIndex :: vector -> IndexType vector -> a + vecEmpty :: vector -instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) (Exp Int) where - indexAt = mkVectorIndex + +instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where + type IndexType (Exp (Vec n a)) = Exp Int + vecIndex = mkVectorIndex + vecEmpty = mkVectorCreate diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 14c043d1f..ab6650300 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -73,6 +73,7 @@ module Data.Array.Accelerate.Smart ( mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), -- ** Smart constructors for vector operations + mkVectorCreate, mkVectorIndex, -- ** Auxiliary functions @@ -865,7 +866,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where Case{} -> internalError "encountered empty case" Cond _ e _ -> typeR e While t _ _ _ -> t - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f Index tp _ _ -> tp LinearIndex tp _ _ -> tp @@ -1179,6 +1180,11 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil x = SmartExp $ Prj PairIdxLeft a -- Operators from Vec +mkVectorCreate :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) +mkVectorCreate = let n :: Int + n = fromIntegral $ natVal $ Proxy @n + in mkExp $ PrimConst $ PrimVectorCreate $ VectorType n singleType + mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a mkVectorIndex = let n :: Int n = fromIntegral $ natVal $ Proxy @n diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 93b0395c0..34b22ef13 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -32,8 +32,6 @@ module Data.Primitive.Vec ( Vec8, pattern Vec8, Vec16, pattern Vec16, - mkVec, - listOfVec, liftVec, From 2c90dd5300dc19fc23964e254b99ea6f002b56d3 Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 26 Oct 2021 21:33:51 +0200 Subject: [PATCH 03/12] Add implementation of empty vector and indexing --- src/Data/Array/Accelerate.hs | 2 ++ src/Data/Array/Accelerate/AST.hs | 5 +-- src/Data/Array/Accelerate/Classes/Vector.hs | 6 ---- src/Data/Array/Accelerate/Interpreter.hs | 9 ++++++ src/Data/Array/Accelerate/Smart.hs | 2 +- src/Data/Array/Accelerate/Trafo/Algebra.hs | 3 ++ src/Data/Primitive/Vec.hs | 34 ++++++++++++++++++--- 7 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 8811695b8..e2543c6ae 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -311,6 +311,8 @@ module Data.Array.Accelerate ( -- ** SIMD vectors Vec, VecElt, Vectoring(..), + vecOfList, + listOfVec, -- ** Type classes -- *** Basic type classes diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index a07920466..066704093 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -149,6 +149,7 @@ import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type import Data.Primitive.Vec +import Data.Primitive.Types import Control.DeepSeq import Data.Kind import Data.Maybe @@ -656,7 +657,7 @@ data PrimConst ty where PrimPi :: FloatingType a -> PrimConst a -- constant for empty Vec - PrimVectorCreate :: KnownNat n => VectorType (Vec n a) -> PrimConst (Vec n a) + PrimVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> PrimConst (Vec n a) -- |Primitive scalar operations @@ -752,7 +753,7 @@ data PrimFun sig where PrimLNot :: PrimFun (PrimBool -> PrimBool) -- local array operators - PrimVectorIndex :: KnownNat n => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) + PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 32a618761..0ab3c4942 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -24,12 +24,6 @@ import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Smart import Data.Primitive.Vec -class Vectoring vector a | vector -> a where - type IndexType vector :: Type - vecIndex :: vector -> IndexType vector -> a - vecEmpty :: vector - - instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 5b8e6401a..344a8691d 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -69,6 +69,7 @@ import qualified Data.Array.Accelerate.Sugar.Array as Sugar import qualified Data.Array.Accelerate.Sugar.Elt as Sugar import qualified Data.Array.Accelerate.Trafo.Delayed as AST +import GHC.TypeLits import Control.DeepSeq import Control.Exception import Control.Monad @@ -1082,6 +1083,7 @@ evalPrimConst :: PrimConst a -> a evalPrimConst (PrimMinBound ty) = evalMinBound ty evalPrimConst (PrimMaxBound ty) = evalMaxBound ty evalPrimConst (PrimPi ty) = evalPi ty +evalPrimConst (PrimVectorCreate ty) = evalVectorCreate ty evalPrim :: PrimFun (a -> r) -> (a -> r) evalPrim (PrimAdd ty) = evalAdd ty @@ -1144,6 +1146,7 @@ evalPrim (PrimMin ty) = evalMin ty evalPrim PrimLAnd = evalLAnd evalPrim PrimLOr = evalLOr evalPrim PrimLNot = evalLNot +evalPrim (PrimVectorIndex v i) = evalVectorIndex v i evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb evalPrim (PrimToFloating ta tb) = evalToFloating ta tb @@ -1168,6 +1171,9 @@ evalLOr (x, y) = fromBool (toBool x || toBool y) evalLNot :: PrimBool -> PrimBool evalLNot = fromBool . not . toBool +evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a +evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i) + evalFromIntegral :: IntegralType a -> NumType b -> a -> b evalFromIntegral ta (IntegralNumType tb) | IntegralDict <- integralDict ta @@ -1213,6 +1219,9 @@ evalMaxBound (IntegralBoundedType ty) evalPi :: FloatingType a -> a evalPi ty | FloatingDict <- floatingDict ty = pi +evalVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> Vec n a +evalVectorCreate (VectorType n _) = vecEmpty + evalSin :: FloatingType a -> (a -> a) evalSin ty | FloatingDict <- floatingDict ty = sin diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index ab6650300..7693ebf45 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1,5 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE CPP #-} + {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 9cfea36ae..1e620435b 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -33,12 +33,14 @@ import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName ) import Data.Array.Accelerate.Trafo.Environment import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Classes.Vector import qualified Data.Array.Accelerate.Debug.Internal.Stats as Stats import Data.Bits import Data.Monoid import Data.Text ( Text ) +import Data.Primitive.Vec import Data.Text.Prettyprint.Doc import Data.Text.Prettyprint.Doc.Render.Text import GHC.Float ( float2Double, double2Float ) @@ -142,6 +144,7 @@ evalPrimApp env f x PrimNEq ty -> evalNEq ty x env PrimMax ty -> evalMax ty x env PrimMin ty -> evalMin ty x env + PrimVectorIndex _ _ -> Nothing PrimLAnd -> evalLAnd x env PrimLOr -> evalLOr x env PrimLNot -> evalLNot x env diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 34b22ef13..10930d4e4 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -5,12 +5,16 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -33,10 +37,13 @@ module Data.Primitive.Vec ( Vec16, pattern Vec16, listOfVec, + vecOfList, liftVec, + Vectoring(..) ) where +import Data.Kind import Data.Proxy import Control.Monad.ST import Control.Monad.Reader @@ -86,14 +93,25 @@ import GHC.Word -- data Vec (n :: Nat) a = Vec ByteArray# -mkVec :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a -mkVec vs = runST $ do +class Vectoring vector a | vector -> a where + type IndexType vector :: Data.Kind.Type + vecIndex :: vector -> IndexType vector -> a + vecEmpty :: vector + +instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where + type IndexType (Vec n a) = Int + vecIndex (Vec ba#) (I# i#) = indexByteArray# ba# i# + vecEmpty = mkVec + + +mkVec :: forall n a. (KnownNat n, Prim a) => Vec n a +mkVec = runST $ do let n :: Int = fromIntegral $ natVal $ Proxy @n mba <- newByteArray (n * sizeOf (undefined :: a)) - zipWithM_ (writeByteArray mba) [0..n] vs ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# + type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where @@ -104,6 +122,14 @@ instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " . map viaShow +vecOfList :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a +vecOfList vs = runST $ do + let n :: Int = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + zipWithM_ (writeByteArray mba) [0..n] vs + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + listOfVec :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] listOfVec (Vec ba#) = go 0# where From 74feaecf673bc615e8464f3a68f102ff721c3f8d Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 27 Oct 2021 14:22:34 +0200 Subject: [PATCH 04/12] Add bounds check on vector index --- src/Data/Primitive/Vec.hs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 10930d4e4..3fa13bf09 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -100,7 +100,10 @@ class Vectoring vector a | vector -> a where instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where type IndexType (Vec n a) = Int - vecIndex (Vec ba#) (I# i#) = indexByteArray# ba# i# + vecIndex (Vec ba#) i@(I# iu#) = let + n :: Int + n = fromIntegral $ natVal $ Proxy @n + in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n) vecEmpty = mkVec From 977669bd9c286a446259036689bb00dc5bc28e59 Mon Sep 17 00:00:00 2001 From: Hugo Date: Thu, 28 Oct 2021 20:55:15 +0200 Subject: [PATCH 05/12] Fix vector creation (todo delete the prim const) --- src/Data/Array/Accelerate/Analysis/Hash.hs | 2 ++ src/Data/Array/Accelerate/Classes/Vector.hs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 75625b9ec..8587742cc 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -389,6 +389,7 @@ encodePrimConst :: PrimConst c -> Builder encodePrimConst (PrimMinBound t) = intHost $(hashQ "PrimMinBound") <> encodeBoundedType t encodePrimConst (PrimMaxBound t) = intHost $(hashQ "PrimMaxBound") <> encodeBoundedType t encodePrimConst (PrimPi t) = intHost $(hashQ "PrimPi") <> encodeFloatingType t +encodePrimConst (PrimVectorCreate t) = intHost $(hashQ "PrimVectorCreate") <> encodeVectorType t encodePrimFun :: PrimFun f -> Builder encodePrimFun (PrimAdd a) = intHost $(hashQ "PrimAdd") <> encodeNumType a @@ -448,6 +449,7 @@ encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a +encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b) encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 0ab3c4942..1eef95abf 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -27,6 +27,6 @@ import Data.Primitive.Vec instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex - vecEmpty = mkVectorCreate + vecEmpty = undef From dc7d849f1256ea88cccb5437bf86d08b10f8fb79 Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 2 Nov 2021 12:24:45 +0100 Subject: [PATCH 06/12] implement interpreter and fix bugs --- src/Data/Array/Accelerate/AST.hs | 1 + src/Data/Array/Accelerate/Analysis/Hash.hs | 1 + src/Data/Array/Accelerate/Classes/Vector.hs | 1 + src/Data/Array/Accelerate/Interpreter.hs | 4 ++++ src/Data/Array/Accelerate/Smart.hs | 9 +++++++++ src/Data/Array/Accelerate/Trafo/Algebra.hs | 1 + src/Data/Primitive/Vec.hs | 10 ++++++++++ 7 files changed, 27 insertions(+) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 066704093..3952c9c60 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -754,6 +754,7 @@ data PrimFun sig where -- local array operators PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) + PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a) -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 8587742cc..f7b22e47f 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -450,6 +450,7 @@ encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b) +encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b) encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 1eef95abf..87586985d 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -27,6 +27,7 @@ import Data.Primitive.Vec instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex + vecWrite = mkVectorWrite vecEmpty = undef diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 344a8691d..06f184348 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1147,6 +1147,7 @@ evalPrim PrimLAnd = evalLAnd evalPrim PrimLOr = evalLOr evalPrim PrimLNot = evalLNot evalPrim (PrimVectorIndex v i) = evalVectorIndex v i +evalPrim (PrimVectorWrite v i) = evalVectorWrite v i evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb evalPrim (PrimToFloating ta tb) = evalToFloating ta tb @@ -1174,6 +1175,9 @@ evalLNot = fromBool . not . toBool evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i) +evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a +evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a + evalFromIntegral :: IntegralType a -> NumType b -> a -> b evalFromIntegral ta (IntegralNumType tb) | IntegralDict <- integralDict ta diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 7693ebf45..4da5568ad 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -75,6 +75,7 @@ module Data.Array.Accelerate.Smart ( -- ** Smart constructors for vector operations mkVectorCreate, mkVectorIndex, + mkVectorWrite, -- ** Auxiliary functions ($$), ($$$), ($$$$), ($$$$$), @@ -1190,6 +1191,11 @@ mkVectorIndex = let n :: Int n = fromIntegral $ natVal $ Proxy @n in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType +mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a) +mkVectorWrite = let n :: Int + n = fromIntegral $ natVal $ Proxy @n + in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType + -- Numeric conversions mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b @@ -1277,6 +1283,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) +mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d +mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c))) + mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 1e620435b..d8a655b06 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -145,6 +145,7 @@ evalPrimApp env f x PrimMax ty -> evalMax ty x env PrimMin ty -> evalMin ty x env PrimVectorIndex _ _ -> Nothing + PrimVectorWrite _ _ -> Nothing PrimLAnd -> evalLAnd x env PrimLOr -> evalLOr x env PrimLNot -> evalLNot x env diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 3fa13bf09..36c4f9570 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -15,6 +15,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TupleSections #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -96,6 +97,7 @@ data Vec (n :: Nat) a = Vec ByteArray# class Vectoring vector a | vector -> a where type IndexType vector :: Data.Kind.Type vecIndex :: vector -> IndexType vector -> a + vecWrite :: vector -> IndexType vector -> a -> vector vecEmpty :: vector instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where @@ -104,6 +106,14 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where n :: Int n = fromIntegral $ natVal $ Proxy @n in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n) + vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do + let n :: Int + n = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n] (listOfVec vec) + zipWithM_ (writeByteArray mba) [0..n] new_vs + ByteArray nba# <- unsafeFreezeByteArray mba + return $! Vec nba# vecEmpty = mkVec From 3fe1e808ebe1ae8bbe17dc5203b82d812f8c026c Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 3 Nov 2021 16:01:04 +0100 Subject: [PATCH 07/12] Remove vector create constant --- src/Data/Array/Accelerate/AST.hs | 7 ------- src/Data/Array/Accelerate/Analysis/Hash.hs | 1 - src/Data/Array/Accelerate/Interpreter.hs | 1 - src/Data/Array/Accelerate/Smart.hs | 6 ------ 4 files changed, 15 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 3952c9c60..6b0f83d24 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -656,10 +656,6 @@ data PrimConst ty where -- constant from Floating PrimPi :: FloatingType a -> PrimConst a - -- constant for empty Vec - PrimVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> PrimConst (Vec n a) - - -- |Primitive scalar operations -- data PrimFun sig where @@ -847,7 +843,6 @@ primConstType = \case PrimMinBound t -> bounded t PrimMaxBound t -> bounded t PrimPi t -> floating t - PrimVectorCreate t -> vector t where bounded :: BoundedType a -> ScalarType a bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t @@ -1122,7 +1117,6 @@ rnfPrimConst :: PrimConst c -> () rnfPrimConst (PrimMinBound t) = rnfBoundedType t rnfPrimConst (PrimMaxBound t) = rnfBoundedType t rnfPrimConst (PrimPi t) = rnfFloatingType t -rnfPrimConst (PrimVectorCreate t) = rnfVectorType t rnfPrimFun :: PrimFun f -> () rnfPrimFun (PrimAdd t) = rnfNumType t @@ -1350,7 +1344,6 @@ liftPrimConst :: PrimConst c -> CodeQ (PrimConst c) liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] -liftPrimConst (PrimVectorCreate t) = [|| PrimVectorCreate $$(liftVectorType t) ||] liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index f7b22e47f..2b399aa46 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -389,7 +389,6 @@ encodePrimConst :: PrimConst c -> Builder encodePrimConst (PrimMinBound t) = intHost $(hashQ "PrimMinBound") <> encodeBoundedType t encodePrimConst (PrimMaxBound t) = intHost $(hashQ "PrimMaxBound") <> encodeBoundedType t encodePrimConst (PrimPi t) = intHost $(hashQ "PrimPi") <> encodeFloatingType t -encodePrimConst (PrimVectorCreate t) = intHost $(hashQ "PrimVectorCreate") <> encodeVectorType t encodePrimFun :: PrimFun f -> Builder encodePrimFun (PrimAdd a) = intHost $(hashQ "PrimAdd") <> encodeNumType a diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 06f184348..c304051ed 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1083,7 +1083,6 @@ evalPrimConst :: PrimConst a -> a evalPrimConst (PrimMinBound ty) = evalMinBound ty evalPrimConst (PrimMaxBound ty) = evalMaxBound ty evalPrimConst (PrimPi ty) = evalPi ty -evalPrimConst (PrimVectorCreate ty) = evalVectorCreate ty evalPrim :: PrimFun (a -> r) -> (a -> r) evalPrim (PrimAdd ty) = evalAdd ty diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 4da5568ad..30981c660 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -73,7 +73,6 @@ module Data.Array.Accelerate.Smart ( mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), -- ** Smart constructors for vector operations - mkVectorCreate, mkVectorIndex, mkVectorWrite, @@ -1181,11 +1180,6 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil x = SmartExp $ Prj PairIdxLeft a -- Operators from Vec -mkVectorCreate :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -mkVectorCreate = let n :: Int - n = fromIntegral $ natVal $ Proxy @n - in mkExp $ PrimConst $ PrimVectorCreate $ VectorType n singleType - mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a mkVectorIndex = let n :: Int n = fromIntegral $ natVal $ Proxy @n From faa139bcdb58fd3f457556508153c4bb2438989d Mon Sep 17 00:00:00 2001 From: Hugo Date: Thu, 4 Nov 2021 20:29:50 +0100 Subject: [PATCH 08/12] add missing pattern match and module in cabal file --- accelerate.cabal | 1 + src/Data/Array/Accelerate/AST.hs | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/accelerate.cabal b/accelerate.cabal index 0b95607e4..2e64e1e1f 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -402,6 +402,7 @@ library Data.Array.Accelerate.Classes.RealFloat Data.Array.Accelerate.Classes.RealFrac Data.Array.Accelerate.Classes.ToFloating + Data.Array.Accelerate.Classes.Vector Data.Array.Accelerate.Debug.Internal.Clock Data.Array.Accelerate.Debug.Internal.Flags Data.Array.Accelerate.Debug.Internal.Graph diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 6b0f83d24..242d015af 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -937,6 +937,11 @@ primFunType = \case i = integral i' in (v `TupRpair` i, single a) + PrimVectorWrite v'@(VectorType _ a) i' -> + let v = singleVector v' + i = integral i' + in (v `TupRpair` (i `TupRpair` single a), v) + -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) From 0e250b8a05494a6f7aff4561add31c62d5321d38 Mon Sep 17 00:00:00 2001 From: Hugo Date: Thu, 2 Dec 2021 16:04:12 +0100 Subject: [PATCH 09/12] Move vec operations to correct AST --- src/Data/Array/Accelerate/AST.hs | 47 ++++----- src/Data/Array/Accelerate/Analysis/Hash.hs | 4 +- src/Data/Array/Accelerate/Classes/Vector.hs | 5 +- src/Data/Array/Accelerate/Interpreter.hs | 2 - .../Array/Accelerate/Representation/Vec.hs | 4 + src/Data/Array/Accelerate/Smart.hs | 30 ++++-- src/Data/Array/Accelerate/Trafo/Algebra.hs | 2 - src/Data/Array/Accelerate/Trafo/Sharing.hs | 68 +++++++------ src/Data/Array/Accelerate/Trafo/Shrink.hs | 6 ++ src/Data/Array/Accelerate/Trafo/Simplify.hs | 4 + .../Array/Accelerate/Trafo/Substitution.hs | 96 ++++++++++--------- 11 files changed, 155 insertions(+), 113 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 242d015af..31b2512ad 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -9,6 +9,8 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE AllowAmbiguousTypes #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.AST @@ -149,7 +151,6 @@ import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type import Data.Primitive.Vec -import Data.Primitive.Types import Control.DeepSeq import Data.Kind import Data.Maybe @@ -560,6 +561,21 @@ data OpenExp env aenv t where -> OpenExp env aenv (Vec n s) -> OpenExp env aenv tup + VecIndex :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv i + -> OpenExp env aenv s + + VecWrite :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv i + -> OpenExp env aenv s + -> OpenExp env aenv (Vec n s) + -- Array indices & shapes IndexSlice :: SliceIndex slix sl co sh -> OpenExp env aenv slix @@ -748,10 +764,6 @@ data PrimFun sig where PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool) PrimLNot :: PrimFun (PrimBool -> PrimBool) - -- local array operators - PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a) - PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a) - -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b) @@ -818,6 +830,8 @@ expType = \case Nil -> TupRunit VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR VecUnpack vecR _ -> vecRtuple vecR + VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s + VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT IndexSlice si _ _ -> shapeType $ sliceShapeR si IndexFull si _ _ -> shapeType $ sliceDomainR si ToIndex{} -> TupRsingle scalarTypeInt @@ -850,9 +864,6 @@ primConstType = \case floating :: FloatingType t -> ScalarType t floating = SingleScalarType . NumSingleType . FloatingNumType - vector :: forall n a. (KnownNat n) => VectorType (Vec n a) -> ScalarType (Vec n a) - vector = VectorScalarType - primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) primFunType = \case -- Num @@ -931,17 +942,6 @@ primFunType = \case PrimLOr -> binary' tbool PrimLNot -> unary' tbool --- Local Vector operations - PrimVectorIndex v'@(VectorType _ a) i' -> - let v = singleVector v' - i = integral i' - in (v `TupRpair` i, single a) - - PrimVectorWrite v'@(VectorType _ a) i' -> - let v = singleVector v' - i = integral i' - in (v `TupRpair` (i `TupRpair` single a), v) - -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) @@ -954,7 +954,6 @@ primFunType = \case compare' a = binary (single a) tbool single = TupRsingle . SingleScalarType - singleVector = TupRsingle . VectorScalarType num = TupRsingle . SingleScalarType . NumSingleType integral = num . IntegralNumType floating = num . FloatingNumType @@ -1092,6 +1091,8 @@ rnfOpenExp topExp = Nil -> () VecPack vecr e -> rnfVecR vecr `seq` rnfE e VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e + VecIndex vt it v i -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i + VecWrite vt it v i e -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i `seq` rnfE e IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix @@ -1184,7 +1185,6 @@ rnfPrimFun (PrimMin t) = rnfSingleType t rnfPrimFun PrimLAnd = () rnfPrimFun PrimLOr = () rnfPrimFun PrimLNot = () -rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f @@ -1313,6 +1313,8 @@ liftOpenExp pexp = Nil -> [|| Nil ||] VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||] VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||] + VecIndex vt it v i -> [|| VecIndex $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) ||] + VecWrite vt it v i e -> [|| VecWrite $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) $$(liftE e) ||] IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||] IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] @@ -1411,7 +1413,6 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] liftPrimFun PrimLAnd = [|| PrimLAnd ||] liftPrimFun PrimLOr = [|| PrimLOr ||] liftPrimFun PrimLNot = [|| PrimLNot ||] -liftPrimFun (PrimVectorIndex v i) = [|| PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] @@ -1461,6 +1462,8 @@ formatExpOp = later $ \case Nil{} -> "Nil" VecPack{} -> "VecPack" VecUnpack{} -> "VecUnpack" + VecIndex{} -> "VecIndex" + VecWrite{} -> "VecWrite" IndexSlice{} -> "IndexSlice" IndexFull{} -> "IndexFull" ToIndex{} -> "ToIndex" diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 2b399aa46..964a5f11a 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -320,6 +320,8 @@ encodeOpenExp exp = Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2 VecPack _ e -> intHost $(hashQ "VecPack") <> travE e VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e + VecIndex _ _ v i -> intHost $(hashQ "VecIndex") <> travE v <> travE i + VecWrite _ _ v i e -> intHost $(hashQ "VecWrite") <> travE v <> travE i <> travE e Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec @@ -448,8 +450,6 @@ encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a -encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b) -encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b) encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs index 87586985d..21c7a7be2 100644 --- a/src/Data/Array/Accelerate/Classes/Vector.hs +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -5,6 +5,8 @@ {-# LANGUAGE MonoLocalBinds #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE GADTs #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | @@ -18,12 +20,13 @@ -- module Data.Array.Accelerate.Classes.Vector where -import Data.Kind import GHC.TypeLits import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Smart import Data.Primitive.Vec + + instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where type IndexType (Exp (Vec n a)) = Exp Int vecIndex = mkVectorIndex diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index c304051ed..aee68443f 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1145,8 +1145,6 @@ evalPrim (PrimMin ty) = evalMin ty evalPrim PrimLAnd = evalLAnd evalPrim PrimLOr = evalLOr evalPrim PrimLNot = evalLNot -evalPrim (PrimVectorIndex v i) = evalVectorIndex v i -evalPrim (PrimVectorWrite v i) = evalVectorWrite v i evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb evalPrim (PrimToFloating ta tb) = evalToFloating ta tb diff --git a/src/Data/Array/Accelerate/Representation/Vec.hs b/src/Data/Array/Accelerate/Representation/Vec.hs index 35eac3b6c..bd37c7f18 100644 --- a/src/Data/Array/Accelerate/Representation/Vec.hs +++ b/src/Data/Array/Accelerate/Representation/Vec.hs @@ -41,6 +41,7 @@ data VecR (n :: Nat) single tuple where VecRnil :: SingleType s -> VecR 0 s () VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) + vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) vecRvector = uncurry VectorType . go where @@ -48,6 +49,9 @@ vecRvector = uncurry VectorType . go go (VecRnil tp) = (0, tp) go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp) +vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s +vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s + vecRtuple :: VecR n s tuple -> TypeR tuple vecRtuple = snd . go where diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 30981c660..ccb38e7ab 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -527,6 +527,21 @@ data PreSmartExp acc exp t where -> exp (Vec n s) -> PreSmartExp acc exp tup + VecIndex :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> exp (Vec n s) + -> exp i + -> PreSmartExp acc exp s + + VecWrite :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> exp (Vec n s) + -> exp i + -> exp s + -> PreSmartExp acc exp (Vec n s) + ToIndex :: ShapeR sh -> exp sh -> exp sh @@ -860,6 +875,8 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where Prj _ _ -> error "I never joke about my work" VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR VecUnpack vecR _ -> vecRtuple vecR + VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s + VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT ToIndex _ _ _ -> TupRsingle scalarTypeInt FromIndex shr _ _ -> shapeType shr Case _ ((_,c):_) -> typeR c @@ -1179,16 +1196,15 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil where x = SmartExp $ Prj PairIdxLeft a --- Operators from Vec + +inferNat :: forall n. KnownNat n => Int +inferNat = fromInteger $ natVal (Proxy @n) + mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -mkVectorIndex = let n :: Int - n = fromIntegral $ natVal $ Proxy @n - in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType +mkVectorIndex (Exp v) (Exp i) = mkExp $ VecIndex (VectorType (inferNat @n) singleType) integralType v i mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a) -mkVectorWrite = let n :: Int - n = fromIntegral $ natVal $ Proxy @n - in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType +mkVectorWrite (Exp v) (Exp i) (Exp el) = mkExp $ VecWrite (VectorType (inferNat @n) singleType) integralType v i el -- Numeric conversions diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index d8a655b06..807ffe474 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -144,8 +144,6 @@ evalPrimApp env f x PrimNEq ty -> evalNEq ty x env PrimMax ty -> evalMax ty x env PrimMin ty -> evalMin ty x env - PrimVectorIndex _ _ -> Nothing - PrimVectorWrite _ _ -> Nothing PrimLAnd -> evalLAnd x env PrimLOr -> evalLOr x env PrimLNot -> evalLNot x env diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 67ead04f0..9a740cb06 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -764,6 +764,8 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp Pair e1 e2 -> AST.Pair (cvt e1) (cvt e2) VecPack vec e -> AST.VecPack vec (cvt e) VecUnpack vec e -> AST.VecUnpack vec (cvt e) + VecIndex vt it v i -> AST.VecIndex vt it (cvt v) (cvt i) + VecWrite vt it v i e -> AST.VecWrite vt it (cvt v) (cvt i) (cvt e) ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix) FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e) Case e rhs -> cvtCase (cvt e) (over (mapped . _2) cvt rhs) @@ -1841,37 +1843,39 @@ makeOccMapSharingExp config accOccMap expOccMap = travE return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height) reconstruct $ case pexp of - Tag tp i -> return (Tag tp i, 0) -- height is 0! - Const tp c -> return (Const tp c, 1) - Undef tp -> return (Undef tp, 1) - Nil -> return (Nil, 1) - Pair e1 e2 -> travE2 Pair e1 e2 - Prj i e -> travE1 (Prj i) e - VecPack vec e -> travE1 (VecPack vec) e - VecUnpack vec e -> travE1 (VecUnpack vec) e - ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix - FromIndex shr sh e -> travE2 (FromIndex shr) sh e - Match t e -> travE1 (Match t) e - Case e rhs -> do - (e', h1) <- travE lvl e - (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] - return (Case e' rhs', h1 `max` maximum h2 + 1) - Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 - While t p iter init -> do - (p' , h1) <- traverseFun1 lvl t p - (iter', h2) <- traverseFun1 lvl t iter - (init', h3) <- travE lvl init - return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) - PrimConst c -> return (PrimConst c, 1) - PrimApp p e -> travE1 (PrimApp p) e - Index tp a e -> travAE (Index tp) a e - LinearIndex tp a i -> travAE (LinearIndex tp) a i - Shape shr a -> travA (Shape shr) a - ShapeSize shr e -> travE1 (ShapeSize shr) e - Foreign tp ff f e -> do - (e', h) <- travE lvl e - return (Foreign tp ff f e', h+1) - Coerce t1 t2 e -> travE1 (Coerce t1 t2) e + Tag tp i -> return (Tag tp i, 0) -- height is 0! + Const tp c -> return (Const tp c, 1) + Undef tp -> return (Undef tp, 1) + Nil -> return (Nil, 1) + Pair e1 e2 -> travE2 Pair e1 e2 + Prj i e -> travE1 (Prj i) e + VecPack vec e -> travE1 (VecPack vec) e + VecUnpack vec e -> travE1 (VecUnpack vec) e + VecIndex vt ti v i -> travE2 (VecIndex vt ti) v i + VecWrite vt ti v i e -> travE3 (VecWrite vt ti) v i e + ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix + FromIndex shr sh e -> travE2 (FromIndex shr) sh e + Match t e -> travE1 (Match t) e + Case e rhs -> do + (e', h1) <- travE lvl e + (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] + return (Case e' rhs', h1 `max` maximum h2 + 1) + Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 + While t p iter init -> do + (p' , h1) <- traverseFun1 lvl t p + (iter', h2) <- traverseFun1 lvl t iter + (init', h3) <- travE lvl init + return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) + PrimConst c -> return (PrimConst c, 1) + PrimApp p e -> travE1 (PrimApp p) e + Index tp a e -> travAE (Index tp) a e + LinearIndex tp a i -> travAE (LinearIndex tp) a i + Shape shr a -> travA (Shape shr) a + ShapeSize shr e -> travE1 (ShapeSize shr) e + Foreign tp ff f e -> do + (e', h) <- travE lvl e + return (Foreign tp ff f e', h+1) + Coerce t1 t2 e -> travE1 (Coerce t1 t2) e where traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) @@ -2755,6 +2759,8 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp Prj i e -> travE1 (Prj i) e VecPack vec e -> travE1 (VecPack vec) e VecUnpack vec e -> travE1 (VecUnpack vec) e + VecIndex vt it v i -> travE2 (VecIndex vt it) v i + VecWrite vt it v i e -> travE3 (VecWrite vt it) v i e ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 574747865..636043113 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -293,6 +293,8 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Pair x y -> Pair <$> shrinkE x <*> shrinkE y VecPack vec e -> VecPack vec <$> shrinkE e VecUnpack vec e -> VecUnpack vec <$> shrinkE e + VecIndex vt it v i -> VecIndex vt it <$> shrinkE v <*> shrinkE i + VecWrite vt it v i e -> VecWrite vt it <$> shrinkE v <*> shrinkE i <*> shrinkE e IndexSlice x ix sh -> IndexSlice x <$> shrinkE ix <*> shrinkE sh IndexFull x ix sl -> IndexFull x <$> shrinkE ix <*> shrinkE sl ToIndex shr sh ix -> ToIndex shr <$> shrinkE sh <*> shrinkE ix @@ -494,6 +496,8 @@ usesOfExp range = countE Pair e1 e2 -> countE e1 <> countE e2 VecPack _ e -> countE e VecUnpack _ e -> countE e + VecIndex _ _ v i -> countE v <> countE i + VecWrite _ _ v i e -> countE v <> countE i <> countE e IndexSlice _ ix sh -> countE ix <> countE sh IndexFull _ ix sl -> countE ix <> countE sl FromIndex _ sh i -> countE sh <> countE i @@ -581,6 +585,8 @@ usesOfPreAcc withShape countAcc idx = count Pair x y -> countE x + countE y VecPack _ e -> countE e VecUnpack _ e -> countE e + VecIndex _ _ v i -> countE v + countE i + VecWrite _ _ v i e -> countE v + countE i + countE e IndexSlice _ ix sh -> countE ix + countE sh IndexFull _ ix sl -> countE ix + countE sl ToIndex _ sh ix -> countE sh + countE ix diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 71be5aad3..6fe611f7a 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -226,6 +226,8 @@ simplifyOpenExp env = first getAny . cvtE Pair e1 e2 -> Pair <$> cvtE e1 <*> cvtE e2 VecPack vec e -> VecPack vec <$> cvtE e VecUnpack vec e -> VecUnpack vec <$> cvtE e + VecIndex vt it v i -> VecIndex vt it <$> cvtE v <*> cvtE i + VecWrite vt it v i e -> VecWrite vt it <$> cvtE v <*> cvtE i <*> cvtE e IndexSlice x ix sh -> IndexSlice x <$> cvtE ix <*> cvtE sh IndexFull x ix sl -> IndexFull x <$> cvtE ix <*> cvtE sl ToIndex shr sh ix -> toIndex shr (cvtE sh) (cvtE ix) @@ -548,6 +550,8 @@ summariseOpenExp = (terms +~ 1) . goE Pair e1 e2 -> travE e1 +++ travE e2 & terms +~ 1 VecPack _ e -> travE e VecUnpack _ e -> travE e + VecIndex _ _ v i -> travE v +++ travE i + VecWrite _ _ v i e -> travE v +++ travE i +++ travE e IndexSlice _ slix sh -> travE slix +++ travE sh & terms +~ 1 -- +1 for sliceIndex IndexFull _ slix sl -> travE slix +++ travE sl & terms +~ 1 -- +1 for sliceIndex ToIndex _ sh ix -> travE sh +++ travE ix diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index e1aa1176b..7debd6d07 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -149,29 +149,31 @@ inlineVars lhsBound expr bound substitute k1 k2 vars topExp = case topExp of Let lhs e1 e2 | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 - Evar (Var t ix) -> Evar . Var t <$> k1 ix - Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 - Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 - Nil -> Just Nil - VecPack vec e1 -> VecPack vec <$> travE e1 - VecUnpack vec e1 -> VecUnpack vec <$> travE e1 - IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 - IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 - ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 - FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 - Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def - Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 - While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 - Const t c -> Just $ Const t c - PrimConst c -> Just $ PrimConst c - PrimApp p e1 -> PrimApp p <$> travE e1 - Index a e1 -> Index a <$> travE e1 - LinearIndex a e1 -> LinearIndex a <$> travE e1 - Shape a -> Just $ Shape a - ShapeSize shr e1 -> ShapeSize shr <$> travE e1 - Undef t -> Just $ Undef t - Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 + -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 + Evar (Var t ix) -> Evar . Var t <$> k1 ix + Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 + Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 + Nil -> Just Nil + VecPack vec e1 -> VecPack vec <$> travE e1 + VecUnpack vec e1 -> VecUnpack vec <$> travE e1 + VecIndex vt it v i -> VecIndex vt it <$> travE v <*> travE i + VecWrite vt it v i e -> VecWrite vt it <$> travE v <*> travE i <*> travE e + IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 + IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 + ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 + FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 + Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def + Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 + While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 + Const t c -> Just $ Const t c + PrimConst c -> Just $ PrimConst c + PrimApp p e1 -> PrimApp p <$> travE e1 + Index a e1 -> Index a <$> travE e1 + LinearIndex a e1 -> LinearIndex a <$> travE e1 + Shape a -> Just $ Shape a + ShapeSize shr e1 -> ShapeSize shr <$> travE e1 + Undef t -> Just $ Undef t + Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 where travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s) @@ -546,31 +548,33 @@ rebuildOpenExp -> f (OpenExp env' aenv' t) rebuildOpenExp v av@(ReindexAvar reindex) exp = case exp of - Const t c -> pure $ Const t c - PrimConst c -> pure $ PrimConst c - Undef t -> pure $ Undef t - Evar var -> expOut <$> v var + Const t c -> pure $ Const t c + PrimConst c -> pure $ PrimConst c + Undef t -> pure $ Undef t + Evar var -> expOut <$> v var Let lhs a b | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b - Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 - Nil -> pure Nil - VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e - VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e - IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh - IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl - ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def - Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e - While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x - PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x - Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh - LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i - Shape a -> Shape <$> reindex a - ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh - Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e - Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e + -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b + Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 + Nil -> pure Nil + VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e + VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e + VecIndex vt it v' i -> VecIndex vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i + VecWrite vt it v' i e -> VecWrite vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i <*> rebuildOpenExp v av e + IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh + IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl + ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def + Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e + While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x + PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x + Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh + LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i + Shape a -> Shape <$> reindex a + ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh + Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e + Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e {-# INLINEABLE rebuildFun #-} rebuildFun From 21d6dab8fad5890675c2184d137a261b8e6ace21 Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 8 Dec 2021 10:58:44 +0100 Subject: [PATCH 10/12] fix off by one errors --- src/Data/Primitive/Vec.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 36c4f9570..ff60d7d2e 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -110,8 +110,8 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where let n :: Int n = fromIntegral $ natVal $ Proxy @n mba <- newByteArray (n * sizeOf (undefined :: a)) - let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n] (listOfVec vec) - zipWithM_ (writeByteArray mba) [0..n] new_vs + let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n-1] (listOfVec vec) + zipWithM_ (writeByteArray mba) [0..n-1] new_vs ByteArray nba# <- unsafeFreezeByteArray mba return $! Vec nba# vecEmpty = mkVec @@ -139,7 +139,7 @@ vecOfList :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a vecOfList vs = runST $ do let n :: Int = fromIntegral $ natVal $ Proxy @n mba <- newByteArray (n * sizeOf (undefined :: a)) - zipWithM_ (writeByteArray mba) [0..n] vs + zipWithM_ (writeByteArray mba) [0..n-1] vs ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# From ad1f995dfa11b8b78f4dde01a95e4e05fb1ea4d3 Mon Sep 17 00:00:00 2001 From: Hugo Date: Mon, 13 Dec 2021 12:32:59 +0100 Subject: [PATCH 11/12] style changes --- src/Data/Array/Accelerate/AST.hs | 4 ++-- src/Data/Primitive/Vec.hs | 36 ++++++++++++++++---------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 31b2512ad..d3a26353e 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -7,10 +8,9 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE AllowAmbiguousTypes #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.AST diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index ff60d7d2e..52e5ccc39 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -1,21 +1,21 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE UnboxedTuples #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TupleSections #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TupleSections #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec From f9556e3c6c1dfbfe946563d274670ae849ab122c Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 19 Jan 2022 16:08:39 +0100 Subject: [PATCH 12/12] prevent memcpy using unsafe mutable coercion --- src/Data/Primitive/Vec.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 52e5ccc39..a50f643c2 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -109,9 +109,8 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do let n :: Int n = fromIntegral $ natVal $ Proxy @n - mba <- newByteArray (n * sizeOf (undefined :: a)) - let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n-1] (listOfVec vec) - zipWithM_ (writeByteArray mba) [0..n-1] new_vs + mba <- unsafeThawByteArray (ByteArray ba#) + writeByteArray mba i v ByteArray nba# <- unsafeFreezeByteArray mba return $! Vec nba# vecEmpty = mkVec