Skip to content

Commit

Permalink
Merge branch 'kmp5/debug/scalar_indexing' of github.com:kmp5VT/ITenso…
Browse files Browse the repository at this point in the history
…rs.jl into kmp5/debug/scalar_indexing
  • Loading branch information
kmp5VT committed Nov 15, 2023
2 parents 85aee99 + 2650f60 commit 21aa398
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 7 deletions.
2 changes: 0 additions & 2 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ using CUDA
using CUDA.CUBLAS
using CUDA.CUSOLVER

## TODO I added copyto and permutedims which match the functions in
## NDTensorsMetalExt because I found similar issues in CUDA
include("imports.jl")
include("default_kwargs.jl")
include("copyto.jl")
Expand Down
4 changes: 1 addition & 3 deletions NDTensors/ext/NDTensorsCUDAExt/copyto.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
## IT looks like CuArray suffers from the same issues as MtlArray.
## To fix this subarray copyto problem I copied same code from MetalExt
## This means we can probably write a generic implmenetation for GPUArrays
# Same definition as `MtlArray`.
function Base.copy(src::Exposed{<:CuArray,<:Base.ReshapedArray})
return reshape(copy(parent(src)), size(unexpose(src)))
end
Expand Down
2 changes: 0 additions & 2 deletions NDTensors/ext/NDTensorsCUDAExt/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ function LinearAlgebra.mul!(
return unexpose(CM)
end

## TODO I wasn't sure the best route to go here, if there is a better route than
## copy please let me know!
## Fix issue in CUDA.jl where it cannot distinguish Transpose{Reshape{Adjoint{CuArray}}}
## as a CuArray and calls generic matmul
function LinearAlgebra.mul!(
Expand Down

0 comments on commit 21aa398

Please sign in to comment.