[NDTensors] Avoid more scalar indexing operations in block sparse GPU code (#1217)
mtfishman authored Oct 25, 2023
1 parent 871e59d commit cea22f9
Showing 10 changed files with 59 additions and 20 deletions.
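The changes below share one theme: on GPU backends (Metal, CUDA), reading or writing a single array element from the host ("scalar indexing") is either very slow or disallowed outright, so element-wise loops and one-off getindex/setindex! calls are replaced with bulk operations or explicit host round trips. A minimal sketch of the failure mode, not part of this commit and assuming GPUArraysCore/Metal.jl behavior on an Apple-silicon machine:

```julia
using Metal, GPUArraysCore

GPUArraysCore.allowscalar(false)   # scalar indexing from the host now throws

a = MtlArray(zeros(Float32, 4))
# a[1] = 1.0f0                     # would error: scalar setindex! on the host
copyto!(a, Float32[1, 2, 3, 4])    # bulk copy stays on the fast path
b = 2 .* a                         # broadcasting also avoids scalar indexing
```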
4 changes: 3 additions & 1 deletion NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl
@@ -2,7 +2,7 @@ module NDTensorsMetalExt

using Adapt
using Functors
using LinearAlgebra: LinearAlgebra
using LinearAlgebra: LinearAlgebra, Transpose, mul!
using NDTensors
using NDTensors.SetParameters

Expand All @@ -18,5 +18,7 @@ include("set_types.jl")
include("indexing.jl")
include("linearalgebra.jl")
include("copyto.jl")
include("append.jl")
include("permutedims.jl")
include("mul.jl")
end
5 changes: 5 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/append.jl
@@ -0,0 +1,5 @@
# This circumvents an issue where `MtlArray` doesn't support `resize!`.
# TODO: Raise an issue with Metal.jl.
function NDTensors.append!!(::Type{<:MtlArray}, collection, collections...)
return vcat(collection, collections...)
end
15 changes: 15 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/mul.jl
@@ -0,0 +1,15 @@
# This was calling generic matrix multiplication.
# TODO: Raise an issue with `Metal.jl`.
function NDTensors.mul!!(
::Type{<:MtlArray},
CM::Transpose,
::Type{<:MtlArray},
AM::AbstractMatrix,
::Type{<:MtlArray},
BM::AbstractMatrix,
α,
β,
)
mul!(transpose(CM), transpose(BM), transpose(AM), α, β)
return CM
end
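The specialization above rewrites the product in transposed form, relying only on the identity (A*B)ᵀ = Bᵀ*Aᵀ, presumably so the kernel sees a plain MtlMatrix destination rather than a `Transpose` wrapper that falls back to generic (scalar-indexing) matrix multiplication. A quick CPU check of the identity, illustrative only and not part of the commit:

```julia
using LinearAlgebra

A, B = rand(3, 4), rand(4, 5)
C = zeros(5, 3)                 # parent storage; CM is its transposed view
CM = transpose(C)               # 3×5, plays the role of the destination above
α, β = 2.0, 0.0

# Same call pattern as the Metal specialization: write into transpose(CM)
# with the operands swapped and transposed.
mul!(transpose(CM), transpose(B), transpose(A), α, β)
@assert CM ≈ α * (A * B)
```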
1 change: 1 addition & 0 deletions NDTensors/src/NDTensors.jl
@@ -53,6 +53,7 @@ include("abstractarray/iscu.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
include("abstractarray/copyto.jl")
include("abstractarray/append.jl")
include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
include("abstractarray/mul.jl")
10 changes: 10 additions & 0 deletions NDTensors/src/abstractarray/append.jl
@@ -0,0 +1,10 @@
# NDTensors.append!
# Used to circumvent issues with some GPU backends like Metal
# not supporting `resize!`.
function append!!(collection, collections...)
return append!!(leaf_parenttype(collection), collection, collections...)
end

function append!!(::Type, collection, collections...)
return append!(collection, collections...)
end
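The two-method pattern above follows the usual NDTensors `!!` convention: dispatch on `leaf_parenttype` first, fall back to in-place `append!` for ordinary arrays, and let backends (such as the Metal extension earlier in this diff) override with an out-of-place `vcat`. Because the result may or may not alias the input, callers must rebind it. A CPU-only usage sketch; the Metal call in the comment is hypothetical and would need an Apple GPU:

```julia
using NDTensors

v = [1.0, 2.0]
v = NDTensors.append!!(v, [3.0, 4.0])   # generic method → in-place append!
@assert v == [1.0, 2.0, 3.0, 4.0]

# With Metal loaded, the same call on MtlArrays would hit the extension's
# method and return `vcat(...)` instead, since `resize!` isn't available:
# w = NDTensors.append!!(MtlArray(Float32[1, 2]), MtlArray(Float32[3, 4]))
```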
7 changes: 5 additions & 2 deletions NDTensors/src/blocksparse/blocksparsetensor.jl
@@ -653,7 +653,10 @@ function uncombine(
#copyto!(Rb,Tb)

if length(Tb) == 1
Rb[] = Tb[]
# Call `cpu` to avoid allowscalar error on GPU.
# TODO: Replace with `@allowscalar`, requires adding
# `GPUArraysCore.jl` as a dependency.
Rb[] = cpu(Tb)[]
else
# XXX: this used to be:
# Rbₐᵣ = ReshapedArray(parent(Rbₐ), size(Tb), ())
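For the single-element case above, the whole block is pulled to the host with `cpu` before the scalar read. The TODO refers to the lighter-weight alternative, `GPUArraysCore.@allowscalar`, which permits exactly one scalar access instead of a device-to-host copy. A hedged sketch of that alternative; the array names are illustrative and `GPUArraysCore.jl` is not currently a dependency:

```julia
using GPUArraysCore: @allowscalar
using Metal   # assumption: Apple GPU available

Tb = MtlArray(Float32[3])   # single-element source block
Rb = MtlArray(Float32[0])   # single-element destination block
@allowscalar Rb[1] = Tb[1]  # one explicitly-permitted scalar read/write
```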
@@ -712,7 +715,7 @@ function permutedims!!(
## copyto!(data(RR), data(R))

if new_nnz > nnz(RR)
dataRR = append!(data(RR), zeros(new_nnz - nnz(RR)))
dataRR = append!!(data(RR), generic_zeros(leaf_parenttype(R), new_nnz - nnz(RR)))
RR = Tensor(BlockSparse(dataRR, bofsRR), inds(RR))
end

24 changes: 13 additions & 11 deletions NDTensors/src/blocksparse/linearalgebra.jl
@@ -6,6 +6,8 @@ const DiagMatrix{ElT,StoreT,IndsT} = DiagTensor{ElT,2,StoreT,IndsT}
function _truncated_blockdim(
S::DiagMatrix, docut::Real; singular_values=false, truncate=true, min_blockdim=0
)
# TODO: Replace `cpu` with `leaf_parenttype` dispatch.
S = cpu(S)
full_dim = diaglength(S)
!truncate && return full_dim
min_blockdim = min(min_blockdim, full_dim)
@@ -84,7 +86,8 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
if blockdim == 0
push!(dropblocks, n)
else
Strunc = tensor(Diag(storage(Ss[n])[1:blockdim]), (blockdim, blockdim))
# TODO: Replace call to `data` with `diagview`.
Strunc = tensor(Diag(data(Ss[n])[1:blockdim]), (blockdim, blockdim))
Us[n] = Us[n][1:dim(Us[n], 1), 1:blockdim]
Ss[n] = Strunc
Vs[n] = Vs[n][1:dim(Vs[n], 1), 1:blockdim]
@@ -177,9 +180,8 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
copyto!(blockview(U, blockU), Ub)

blockviewS = blockview(S, blockS)
for i in 1:diaglength(Sb)
setdiagindex!(blockviewS, getdiagindex(Sb, i), i)
end
# TODO: Replace `data` with `diagview`.
copyto!(data(blockviewS), data(Sb))

#<fermions>
sV = left_arrow_sign(vind, blockV[2])
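Replacing the `setdiagindex!`/`getdiagindex` loop with a single `copyto!` over the diagonal data is the core scalar-indexing fix in this file: one bulk device copy instead of `diaglength(Sb)` host round trips. The same rewrite, reduced to plain arrays and purely illustrative:

```julia
src = [1.0, 2.0, 3.0]        # stand-in for the diagonal data of Sb
dst = zeros(3)               # stand-in for data(blockviewS)

# Before: one scalar read + one scalar write per diagonal element.
# for i in eachindex(src)
#     dst[i] = src[i]
# end

copyto!(dst, src)            # After: a single bulk copy.
@assert dst == src
```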
@@ -243,7 +245,8 @@ function eigen(
if blockdim == 0
push!(dropblocks, n)
else
Dtrunc = tensor(Diag(storage(Ds[n])[1:blockdim]), (blockdim, blockdim))
# TODO: Replace call to `data` with `diagview`.
Dtrunc = tensor(Diag(data(Ds[n])[1:blockdim]), (blockdim, blockdim))
Ds[n] = Dtrunc
new_size = (dim(Vs[n], 1), blockdim)
new_data = array(Vs[n])[1:new_size[1], 1:new_size[2]]
@@ -311,12 +314,11 @@ function eigen(

blockD = nzblocksD[n]
blockviewD = blockview(D, blockD)
for i in 1:diaglength(Db)
setdiagindex!(blockviewD, getdiagindex(Db, i), i)
end
# TODO: Replace `data` with `diagview`.
copyto!(data(blockviewD), data(Db))

blockV = nzblocksV[n]
blockview(V, blockV) .= Vb
copyto!(blockview(V, blockV), Vb)
end

return D, V, Spectrum(d, truncerr)
Expand Down Expand Up @@ -380,8 +382,8 @@ function qx(qx::Function, T::BlockSparseTensor{<:Any,2}; kwargs...)
X = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksX, indsX)

for n in 1:nnzblocksT
blockview(Q, nzblocksQ[n]) .= Qs[n]
blockview(X, nzblocksX[n]) .= Xs[n]
copyto!(blockview(Q, nzblocksQ[n]), Qs[n])
copyto!(blockview(X, nzblocksX[n]), Xs[n])
end

Q = adapt(leaf_parenttype(T), Q)
8 changes: 4 additions & 4 deletions NDTensors/src/dense/tensoralgebra/contract.jl
@@ -15,11 +15,11 @@ function _contract_scalar!(
β=zero(ElR),
) where {ElR}
if iszero(β)
R[1] = α * T1 * T2
R[] = α * T1 * T2
elseif iszero(α)
R[1] = β * R[1]
R[] = β * R[]
else
R[1] = α * T1 * T2 + β * R[1]
R[] = α * T1 * T2 + β * R[]
end
return R
end
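The `R[1]` → `R[]` change switches to the zero-index form for the scalar-like result; for single-element containers the two forms agree on CPU, as a quick check (not from the commit) shows:

```julia
x = fill(2.0)                # zero-dimensional array
@assert x[] == x[1] == 2.0
x[] = 4.0                    # zero-index setindex! on the 0-dim array
@assert x[] == 4.0

y = [3.0]                    # one-element vector
@assert y[] == y[1] == 3.0
```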
@@ -150,7 +150,7 @@ function _contract_scalar!(
β=zero(ElR),
) where {ElR}
if nnz(T1) == nnz(T2) == 1
_contract_scalar!(R, labelsR, T1[1], labelsT1, T2[1], labelsT2, α, β)
_contract_scalar!(R, labelsR, T1[], labelsT1, T2[], labelsT2, α, β)
else
_contract_scalar_maybe_perm!(R, labelsR, T1, labelsT1, T2, labelsT2, α, β)
end
3 changes: 2 additions & 1 deletion NDTensors/src/linearalgebra/linearalgebra.jl
@@ -231,7 +231,8 @@ function eigen(
DM, VM = eigen(matrixT)

# Sort by largest to smallest eigenvalues
p = sortperm(DM; rev=true, by=abs)
# TODO: Replace `cpu` with `leaf_parenttype` dispatch.
p = sortperm(cpu(DM); rev=true, by=abs)
DM = DM[p]
VM = VM[:, p]

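As in `_truncated_blockdim`, only the permutation is computed on the host (presumably because `sortperm` with a `by` function is not reliably supported by all GPU backends), while the reordering `DM[p]` / `VM[:, p]` remains a bulk gather that GPU arrays do support. A CPU-only illustration of the same steps:

```julia
DM = [0.1, -2.0, 0.5]                 # stand-in for the eigenvalues
p = sortperm(DM; rev=true, by=abs)    # host-side permutation
DM = DM[p]                            # bulk gather, GPU-friendly
@assert abs.(DM) == sort(abs.(DM); rev=true)
```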
2 changes: 1 addition & 1 deletion NDTensors/src/truncate.jl
@@ -11,7 +11,7 @@ end
# GPU fallback version, convert to CPU.
function truncate!!(::Type{<:AbstractArray}, P::AbstractArray; kwargs...)
P_cpu = cpu(P)
P_cpu, truncerr, docut = truncate!(P_cpu; kwargs...)
truncerr, docut = truncate!(P_cpu; kwargs...)
P = adapt(leaf_parenttype(P), P_cpu)
return P, truncerr, docut
end
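This hunk also fixes the destructuring so it matches `truncate!`'s two return values, with the mutated `P_cpu` moved back to the original storage type via `adapt`. The host round trip itself is the standard fallback pattern; a CPU-only sketch with `Adapt` (illustrative names, no GPU required):

```julia
using Adapt

x = rand(Float32, 8)           # stand-in for a device array of probabilities
x_host = adapt(Array, x)       # device → host (a plain copy on CPU)
sort!(x_host; rev=true)        # stand-in for the CPU-only work (truncate!)
x = adapt(Array, x_host)       # host → original storage; on GPU this would be
                               # adapt(leaf_parenttype(P), P_cpu) as above
```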
