Skip to content

Commit

Permalink
Attempt to update diag using expose
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Jun 17, 2024
1 parent 4f5e96c commit b0f1535
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
module NDTensorsGPUArraysCoreExt
include("contract.jl")
include("diag.jl")
end
11 changes: 11 additions & 0 deletions NDTensors/ext/NDTensorsGPUArraysCoreExt/diag.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using GPUArraysCore: AbstractGPUArray
using NDTensors: NDTensors, BlockSparseTensor, dense, diag
using NDTensors.Expose: Exposed, unexpose

## TODO to circumvent issues with blocksparse and scalar indexing
## convert blocksparse GPU tensors to dense tensors and call diag
## copying will probably have some impact on timing but this code
## currently isn't used in the main code, just in tests.
function NDTensors.diag(ETensor::Exposed{<:AbstractGPUArray,<:BlockSparseTensor})
return diag(dense(unexpose(ETensor)))
end
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/blocksparsetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,13 @@ function dense(T::TensorT) where {TensorT<:BlockSparseTensor}
return tensor(Dense(r), inds(T))
end

function diag(ETensor::Exposed{<:AbstractArray, BlockSparseTensor})
function diag(ETensor::Exposed{<:AbstractArray,<:BlockSparseTensor})
tensor = unexpose(ETensor)
tensordiag = NDTensors.similar(
dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),)
)
for j in 1:diaglength(tensor)
@inbounds tensor_diag[j] = getdiagindex(tensor, j)
@inbounds tensordiag[j] = getdiagindex(tensor, j)
end
return tensordiag
end
Expand Down
6 changes: 3 additions & 3 deletions NDTensors/src/tensor/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,16 @@ using .Expose: Exposed, expose, unexpose
# block sparse vector instead of dense.
diag(tensor::Tensor) = diag(expose(tensor))

function diag(Etensor::Exposed)
tensor = unexpose(Etensor)
function diag(ETensor::Exposed)
tensor = unexpose(ETensor)
## d = NDTensors.similar(T, ElT, (diaglength(T),))
tensordiag = NDTensors.similar(
dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),)
)
array(tensordiag) .= diagview(tensor)
return tensordiag
end

"""
setdiagindex!
Expand Down

0 comments on commit b0f1535

Please sign in to comment.