Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Fix scalar indexing issue for Diag broadcast on GPU #1497

Merged
merged 49 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
3e77951
Working on GPU diag problem
kmp5VT Jun 13, 2024
721d0f0
format
kmp5VT Jun 13, 2024
9e222ed
Use data over array
kmp5VT Jun 13, 2024
7e95335
remove adapt
kmp5VT Jun 13, 2024
aef7db6
Make expose permutedims for diagtensor to fix cpu error
kmp5VT Jun 13, 2024
f8a80c9
format
kmp5VT Jun 13, 2024
c3df57f
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 13, 2024
b61b4de
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 13, 2024
f96f322
Remove permutedims function
kmp5VT Jun 13, 2024
7807882
Try to make GPUs more supported by Diag
kmp5VT Jun 13, 2024
8f01962
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 13, 2024
6ef078f
Add a comment with a link to the bug in Metal.jl
kmp5VT Jun 13, 2024
c0ad602
Remove unused line (A request from Miles)
kmp5VT Jun 13, 2024
b3892df
Fix diag function for GPU code
kmp5VT Jun 13, 2024
cc24e62
Make diagview functions for Tensor types
kmp5VT Jun 13, 2024
2af30bf
Remove unecessary function
kmp5VT Jun 13, 2024
9f555bc
Update permutedim functions
kmp5VT Jun 13, 2024
a3c13c6
return fill for diagview uniformdiagtensor
kmp5VT Jun 13, 2024
41bf2af
remove uniformdiag definition
kmp5VT Jun 13, 2024
7d1ef42
remove comment
kmp5VT Jun 13, 2024
ad4d96c
Move dense diagview to densetensor
kmp5VT Jun 13, 2024
06d9503
typo
kmp5VT Jun 13, 2024
2140eb4
use diagview over data
kmp5VT Jun 13, 2024
51323fc
remove diaview
kmp5VT Jun 13, 2024
dcf778e
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 14, 2024
d8cb8bd
array necessary here because .= fails for CUDA gpu
kmp5VT Jun 14, 2024
1907a47
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 14, 2024
4f5e96c
Working on diag function
kmp5VT Jun 17, 2024
b0f1535
Attempt to update diag using expose
kmp5VT Jun 17, 2024
80a9097
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 17, 2024
1fac291
Update if statement
kmp5VT Jun 17, 2024
7758242
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 19, 2024
7e2ce77
Move map_diag! to NDTensors
kmp5VT Jun 19, 2024
6ed349a
Don't use NDTensors map_diag yet
kmp5VT Jun 19, 2024
5d04949
make map_diag expose based and create GPU version for blocksparse wit…
kmp5VT Jun 19, 2024
60f2abc
format
kmp5VT Jun 19, 2024
9433798
remove expose
kmp5VT Jun 19, 2024
3df285e
simplify dense definition for DiagTensor
kmp5VT Jun 19, 2024
8e49288
remove unused code
kmp5VT Jun 19, 2024
caba6a4
Rename diag to blocksparsetensor.jl
kmp5VT Jun 20, 2024
ea52a82
Try forcing Pkg to update the registry to fix ci issue
kmp5VT Jun 20, 2024
a48ddb0
Use approx because of numerical noise
kmp5VT Jun 20, 2024
b95c01e
Merge branch 'main' into kmp5/debug/issue_1482
kmp5VT Jun 20, 2024
f5c2d39
Add map_diag tests
kmp5VT Jun 20, 2024
d0dcb4d
format
kmp5VT Jun 20, 2024
ae19547
remove w
kmp5VT Jun 20, 2024
b2b0b19
Format
mtfishman Jun 20, 2024
255f66e
Format
mtfishman Jun 20, 2024
ab0728f
Format
mtfishman Jun 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions NDTensors/src/diag/diagtensor.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .DiagonalArrays: diaglength
using .DiagonalArrays: diaglength, diagview

const DiagTensor{ElT,N,StoreT,IndsT} = Tensor{ElT,N,StoreT,IndsT} where {StoreT<:Diag}
const NonuniformDiagTensor{ElT,N,StoreT,IndsT} =
Expand All @@ -9,9 +9,7 @@ const UniformDiagTensor{ElT,N,StoreT,IndsT} =
function diag(tensor::DiagTensor)
tensor_diag = NDTensors.similar(dense(typeof(tensor)), (diaglength(tensor),))
# TODO: Define `eachdiagindex`.
for j in 1:diaglength(tensor)
tensor_diag[j] = getdiagindex(tensor, j)
end
diagview(array(tensor_diag)) .= diagview(array(tensor))
return tensor_diag
end

Expand Down Expand Up @@ -145,16 +143,14 @@ function permutedims!(
f::Function=(r, t) -> t,
) where {N}
# TODO: check that inds(R)==permute(inds(T),perm)?
for i in 1:diaglength(R)
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
end
diagview(data(R)) .= f.(diagview(data(R)), diagview(data(T)))
return R
end

function permutedims(
T::DiagTensor{<:Number,N}, perm::NTuple{N,Int}, f::Function=identity
) where {N}
R = NDTensors.similar(T, permute(inds(T), perm))
R = NDTensors.similar(T)
g(r, t) = f(t)
permutedims!(R, T, perm, g)
return R
Expand Down Expand Up @@ -193,9 +189,9 @@ end
function permutedims!(
R::DenseTensor{ElR,N}, T::DiagTensor{ElT,N}, perm::NTuple{N,Int}, f::Function=(r, t) -> t
) where {ElR,ElT,N}
for i in 1:diaglength(T)
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
end
rview = diagview(array(R))
tview = diagview(T)
rview .= f.(rview, tview)
return R
end

Expand Down
14 changes: 11 additions & 3 deletions NDTensors/test/test_diag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,30 @@ using LinearAlgebra: dot
@test complex(D) == Diag(one(complex(elt)))
@test similar(D) == Diag(0.0)

D = Tensor(Diag(1), (2, 2))
D = dev(Tensor(Diag(1), (2, 2)))
@test norm(D) == √2
d = 3
vr = rand(elt, d)
D = dev(tensor(Diag(vr), (d, d)))
Da = Array(D)
Dm = Matrix(D)
Da = permutedims(D, (2, 1))
@allowscalar begin
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)

## TODO Currently this permutedims requires scalar indexing on GPU.
Da = permutedims(D, (2, 1))
@test Da == D
end

if (dev == NDTensors.mtl && elt != ComplexF32)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
S = permutedims(dev(D), (1, 2), sqrt)
@allowscalar begin
for i in 1:diaglength(S)
@test S[i, i] == sqrt(D[i, i])
end
end
end

# Regression test for https://github.com/ITensor/ITensors.jl/issues/1199
S = dev(tensor(Diag(randn(elt, 2)), (2, 2)))
## This was creating a `Dense{ReshapedArray{Adjoint{Matrix}}}` which, in mul!, was
Expand Down
Loading