diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index bb39b7955d..fbe612a00f 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.4" +version = "1.1.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -18,6 +18,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] @@ -27,6 +28,7 @@ WeightInitializersChainRulesCoreExt = "ChainRulesCore" WeightInitializersGPUArraysExt = "GPUArrays" WeightInitializersMetalExt = ["Metal", "GPUArrays"] WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] +WeightInitializersReactantExt = "Reactant" [compat] AMDGPU = "0.9.6, 1" @@ -39,6 +41,7 @@ GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" Metal = "1.3.0" Random = "1.10" +Reactant = "0.2.16" SpecialFunctions = "2.4" Statistics = "1.10" julia = "1.10" diff --git a/lib/WeightInitializers/ext/WeightInitializersReactantExt.jl b/lib/WeightInitializers/ext/WeightInitializersReactantExt.jl new file mode 100644 index 0000000000..980659e10c --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersReactantExt.jl @@ -0,0 +1,23 @@ +module WeightInitializersReactantExt + +using Reactant: Reactant, TracedUtils, TracedRNG, ConcreteRNG, TracedRArray +using WeightInitializers: DeviceAgnostic + +# random numbers are automatically handled + +function DeviceAgnostic.ones(::ConcreteRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return Reactant.to_rarray(ones(T, dims...)) +end +function DeviceAgnostic.zeros( + ::ConcreteRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return Reactant.to_rarray(zeros(T, dims...)) +end + +function DeviceAgnostic.ones(::TracedRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return TracedUtils.promote_to(TracedRArray{T, length(dims)}, ones(T, dims...)) +end +function DeviceAgnostic.zeros(::TracedRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return TracedUtils.promote_to(TracedRArray{T, length(dims)}, zeros(T, dims...)) +end + +end