Add `tabulate` function #18
Implementing

I didn't realize that the idea was that we provide these functions before closing this issue; I'll reopen this and implement such a function.

Supporting a fully general

I think at the top you mean
Here are the uses we have for it. Note that the nested functors will eventually end up as 2-dimensional matrices:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import qualified Data.Functor.Rep as Rep

-- | Zero functor with a one at index `i` (a one-hot vector).
oneHot ::
  ( Rep.Representable f,
    Num a,
    Eq (Rep.Rep f)
  ) =>
  Rep.Rep f ->
  f a
oneHot i = Rep.tabulate (\i' -> if i == i' then 1 else 0)

-- | Generalized identity matrix: row `i` is `oneHot i`.
idM ::
  ( Rep.Representable f,
    Num a,
    Eq (Rep.Rep f)
  ) =>
  f (f a)
idM = Rep.tabulate oneHot
```

And here's the other place:

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

import Data.Proxy (Proxy (..))
import qualified Data.Functor.Rep as Rep
import GHC.TypeLits (KnownNat, natVal)
-- Assumption: the issue does not show these imports; `Pointed` and `Zip`
-- are taken to come from the `pointed` and `semialign` packages.
import Data.Pointed (Pointed)
import qualified Data.Zip as Zip

-- | Create a functor of all kernel positions of a convolution.
makeKernelPosFunctor ::
  forall
    dilationRate
    kernelPosFunctor
    windowFunctor
    kernelFunctor.
  ( KnownNat dilationRate,
    Integral (Rep.Rep windowFunctor),
    Integral (Rep.Rep kernelPosFunctor),
    Integral (Rep.Rep kernelFunctor),
    Pointed kernelPosFunctor,
    Zip.Zip kernelPosFunctor,
    Rep.Representable windowFunctor,
    Rep.Representable kernelFunctor,
    Rep.Representable kernelPosFunctor,
    Functor kernelPosFunctor
  ) =>
  kernelPosFunctor (kernelFunctor (Rep.Rep windowFunctor))
makeKernelPosFunctor =
  Rep.tabulate (makeIndexFunctor @windowFunctor @kernelPosFunctor dilationRate)
  where
    dilationRate = toInteger (natVal (Proxy @dilationRate))

-- | For an index (`idx`) of an output functor of a convolution, create the
-- corresponding kernel position, i.e. a `kernelFunctor` of indices of an input
-- functor.
-- Each index `j` of `kernelFunctor` is mapped to `offSet + dilationRate * j`,
-- where `offSet`, the position of the kernel's first entry, is `idx` itself.
makeIndexFunctor ::
  ( Rep.Representable windowInFunctor,
    Rep.Representable windowOutFunctor,
    Rep.Representable kernelFunctor,
    Integral (Rep.Rep windowInFunctor),
    Integral (Rep.Rep windowOutFunctor),
    Integral (Rep.Rep kernelFunctor),
    Integral dilationRate
  ) =>
  dilationRate ->
  Rep.Rep windowOutFunctor ->
  kernelFunctor (Rep.Rep windowInFunctor)
makeIndexFunctor dilationRate idx = Rep.tabulate tabulator
  where
    offSet = fromIntegral idx
    tabulator x =
      fromIntegral (shiftIndex (fromIntegral dilationRate) offSet (fromIntegral x))

-- | Transform an index (`idx`) of a `kernelFunctor` to the corresponding index of
-- an input functor for a given kernel position (specified by `offSet`).
shiftIndex :: Integer -> Integer -> Integer -> Integer
shiftIndex dilationRate offSet idx = offSet + dilationRate * idx
```

Does this help?
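To see what these functions produce on a concrete functor, here is a minimal sketch assuming the `adjunctions` package (for `Data.Functor.Rep`) and the `distributive` package (for `Data.Distributive`). The 3-element type `V3` and the helper `kernelPositions` are hypothetical and exist only for illustration; `kernelPositions` is a simplified analogue of `makeKernelPosFunctor` with fixed sizes and a value-level dilation rate instead of a type-level one.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Distributive (Distributive (..))
import Data.Functor.Rep (Representable (..), distributeRep)

-- | Hypothetical fixed-size vector, used only for illustration.
data V3 a = V3 a a a
  deriving (Show, Eq)

instance Functor V3 where
  fmap f (V3 x y z) = V3 (f x) (f y) (f z)

instance Distributive V3 where
  distribute = distributeRep

-- Indices are the Ints 0, 1 and 2; `index` sends any other Int to the last slot.
instance Representable V3 where
  type Rep V3 = Int
  tabulate f = V3 (f 0) (f 1) (f 2)
  index (V3 x _ _) 0 = x
  index (V3 _ y _) 1 = y
  index (V3 _ _ z) _ = z

-- Same definitions as in the issue, written with unqualified names.
oneHot :: (Representable f, Num a, Eq (Rep f)) => Rep f -> f a
oneHot i = tabulate (\i' -> if i == i' then 1 else 0)

idM :: (Representable f, Num a, Eq (Rep f)) => f (f a)
idM = tabulate oneHot

-- | Hypothetical, simplified analogue of `makeKernelPosFunctor`: for each of
-- three output positions `offSet`, the three input indices the kernel reads,
-- namely `offSet + dilationRate * j` (the same formula as `shiftIndex`).
kernelPositions :: Int -> V3 (V3 Int)
kernelPositions dilationRate =
  tabulate (\offSet -> tabulate (\j -> offSet + dilationRate * j))
```

In GHCi, `idM :: V3 (V3 Int)` evaluates to `V3 (V3 1 0 0) (V3 0 1 0) (V3 0 0 1)`, and `kernelPositions 2` gives `V3 (V3 0 2 4) (V3 1 3 5) (V3 2 4 6)`, so the nested `tabulate` calls produce the row-by-row 2-dimensional layout directly.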
We're not sure we need this one, so it's OK to defer. |
This corresponds to `ZipCat` in ConCat, which should correspond with `generate` in Accelerate.
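Accelerate's `generate` takes a shape and a function from indices to elements and builds the array, which is the array-level counterpart of `tabulate`. As a rough illustration only, here is a sketch of that pattern using the `array` package that ships with GHC rather than Accelerate itself; `generateArray` and `identityArray` are hypothetical names, and the 2-D array stands in for the nested-functor matrices discussed in this thread.

```haskell
import Data.Array (Array, Ix, array, range)

-- | Hypothetical `generate`-style constructor: build an array over the index
-- range `bnds` by evaluating `f` at every index, much as `tabulate` builds a
-- functor from a function of its `Rep`.
generateArray :: Ix i => (i, i) -> (i -> e) -> Array i e
generateArray bnds f = array bnds [(ix, f ix) | ix <- range bnds]

-- | The `idM` idea expressed as an n-by-n array instead of nested functors.
identityArray :: Num e => Int -> Array (Int, Int) e
identityArray n =
  generateArray ((0, 0), (n - 1, n - 1)) (\(i, j) -> if i == j then 1 else 0)
```

The only difference from the functor version is where the shape lives: here the bounds are a runtime value, whereas with `Representable` the shape is fixed by the type.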