Skip to content

Commit

Permalink
feat: handle weight initializers for reactant RNGs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 7, 2025
1 parent 82349e7 commit 2077944
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lib/WeightInitializers/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.4"
version = "1.1.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand All @@ -18,6 +18,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
Expand All @@ -27,6 +28,7 @@ WeightInitializersChainRulesCoreExt = "ChainRulesCore"
WeightInitializersGPUArraysExt = "GPUArrays"
WeightInitializersMetalExt = ["Metal", "GPUArrays"]
WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"]
WeightInitializersReactantExt = "Reactant"

[compat]
AMDGPU = "0.9.6, 1"
Expand All @@ -39,6 +41,7 @@ GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
Metal = "1.3.0"
Random = "1.10"
Reactant = "0.2.16"
SpecialFunctions = "2.4"
Statistics = "1.10"
julia = "1.10"
Expand Down
23 changes: 23 additions & 0 deletions lib/WeightInitializers/ext/WeightInitializersReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module WeightInitializersReactantExt

using Reactant: Reactant, TracedUtils, TracedRNG, ConcreteRNG, TracedRArray
using WeightInitializers: DeviceAgnostic

# random numbers are automatically handled

function DeviceAgnostic.ones(::ConcreteRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return Reactant.to_rarray(ones(T, dims...))
end
function DeviceAgnostic.zeros(
::ConcreteRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return Reactant.to_rarray(zeros(T, dims...))
end

function DeviceAgnostic.ones(::TracedRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return TracedUtils.promote_to(TracedRArray{T, length(dims)}, ones(T, dims...))
end
function DeviceAgnostic.zeros(::TracedRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return TracedUtils.promote_to(TracedRArray{T, length(dims)}, zeros(T, dims...))
end

end

0 comments on commit 2077944

Please sign in to comment.