diff --git a/NDTensors/ext/NDTensorsCUDAExt/fill.jl b/NDTensors/ext/NDTensorsCUDAExt/fill.jl index 984a7f8532..dafa2fe694 100644 --- a/NDTensors/ext/NDTensorsCUDAExt/fill.jl +++ b/NDTensors/ext/NDTensorsCUDAExt/fill.jl @@ -1,12 +1,12 @@ -import NDTensors.SetParameters: set_unspecified_parameters, DefaultParameters -import NDTensors: leaf_parenttype +using NDTensors.SetParameters: set_unspecified_parameters, DefaultParameters +using NDTensors: leaf_parenttype -function NDTensors.generic_randn(arraytype::Type{<:CuArray}, dim::Integer=0) - arraytype_specified = set_unspecified_parameters( - leaf_parenttype(arraytype), DefaultParameters() - ) - return CUDA.randn(eltype(arraytype_specified), dim) -end +# function NDTensors.generic_randn(arraytype::Type{<:CuArray}, dim::Integer=0) +# arraytype_specified = set_unspecified_parameters( +# leaf_parenttype(arraytype), DefaultParameters() +# ) +# return CUDA.randn(eltype(arraytype_specified), dim) +# end function NDTensors.generic_zeros(arraytype::Type{<:CuArray}, dim::Integer=0) arraytype_specified = NDTensors.set_unspecified_parameters( diff --git a/NDTensors/src/abstractarray/fill.jl b/NDTensors/src/abstractarray/fill.jl index 85e8377d3a..439a115dd8 100644 --- a/NDTensors/src/abstractarray/fill.jl +++ b/NDTensors/src/abstractarray/fill.jl @@ -1,13 +1,9 @@ -function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0) +function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0; rng = Random.default_rng()) arraytype_specified = set_unspecified_parameters( leaf_parenttype(arraytype), DefaultParameters() ) data = similar(arraytype_specified, dim) - ElT = eltype(data) - for i in 1:length(data) - data[i] = randn(ElT) - end - return data + NDTensors.randn!(rng, data) end function generic_zeros(arraytype::Type{<:AbstractArray}, dim::Integer=0)