Skip to content

Commit

Permalink
Add iscu function
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Oct 8, 2023
1 parent 0f310e9 commit 46e00ce
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ else
end

include("imports.jl")

iscu(::Type{<:CuArray}) = true
include("set_types.jl")
include("adapt.jl")
include("linearalgebra.jl")
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import NDTensors: cu, set_ndims, set_eltype, set_eltype_if_unspecified, similartype
import NDTensors:
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter, iscu

Check warning on line 4 in NDTensors/ext/NDTensorsCUDAExt/imports.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: NDTensors/ext/NDTensorsCUDAExt/imports.jl:4:-import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter, iscu NDTensors/ext/NDTensorsCUDAExt/imports.jl:4:+import NDTensors.SetParameters: NDTensors/ext/NDTensorsCUDAExt/imports.jl:5:+ nparameters, get_parameter, set_parameter, default_parameter, iscu

import .CUDA: CuArrayAdaptor
3 changes: 3 additions & 0 deletions NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ function similar(arraytype::Type{<:AbstractArray}, dims::Tuple)
return similartype(arraytype, shape)(undef, NDTensors.to_shape(arraytype, shape))
end

# For when there are CUArray specific issues inline
iscu(A::AbstractArray) = iscu(leaf_parenttype(A))
iscu(::Type{<:AbstractArray}) = false
# This function actually allocates the data.
# Catches conversions of dimensions specified by ranges
# dimensions specified by integers with `Base.to_shape`.
Expand Down

0 comments on commit 46e00ce

Please sign in to comment.