Skip to content

Commit

Permalink
import -> using. Can use NDTenosrs randn! instead of another function
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Sep 14, 2023
1 parent a4baeb7 commit 03a9b27
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
16 changes: 8 additions & 8 deletions NDTensors/ext/NDTensorsCUDAExt/fill.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import NDTensors.SetParameters: set_unspecified_parameters, DefaultParameters
import NDTensors: leaf_parenttype
using NDTensors.SetParameters: set_unspecified_parameters, DefaultParameters
using NDTensors: leaf_parenttype

function NDTensors.generic_randn(arraytype::Type{<:CuArray}, dim::Integer=0)
arraytype_specified = set_unspecified_parameters(
leaf_parenttype(arraytype), DefaultParameters()
)
return CUDA.randn(eltype(arraytype_specified), dim)
end
# function NDTensors.generic_randn(arraytype::Type{<:CuArray}, dim::Integer=0)
# arraytype_specified = set_unspecified_parameters(
# leaf_parenttype(arraytype), DefaultParameters()
# )
# return CUDA.randn(eltype(arraytype_specified), dim)
# end

function NDTensors.generic_zeros(arraytype::Type{<:CuArray}, dim::Integer=0)
arraytype_specified = NDTensors.set_unspecified_parameters(
Expand Down
8 changes: 2 additions & 6 deletions NDTensors/src/abstractarray/fill.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0)
function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0; rng = Random.default_rng())

Check warning on line 1 in NDTensors/src/abstractarray/fill.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: NDTensors/src/abstractarray/fill.jl:1:-function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0; rng = Random.default_rng()) NDTensors/src/abstractarray/fill.jl:1:+function generic_randn( NDTensors/src/abstractarray/fill.jl:2:+ arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng() NDTensors/src/abstractarray/fill.jl:3:+)
arraytype_specified = set_unspecified_parameters(
leaf_parenttype(arraytype), DefaultParameters()
)
data = similar(arraytype_specified, dim)
ElT = eltype(data)
for i in 1:length(data)
data[i] = randn(ElT)
end
return data
NDTensors.randn!(rng, data)

Check warning on line 6 in NDTensors/src/abstractarray/fill.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: NDTensors/src/abstractarray/fill.jl:6:- NDTensors.randn!(rng, data) NDTensors/src/abstractarray/fill.jl:8:+ return NDTensors.randn!(rng, data)
end

function generic_zeros(arraytype::Type{<:AbstractArray}, dim::Integer=0)
Expand Down

0 comments on commit 03a9b27

Please sign in to comment.