Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Circumvent scalar indexing to improve GPU performance #1216

Closed
wants to merge 17 commits into from
Closed
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function LinearAlgebra.svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}

dropblocks = Int[]
if truncate
truncerr, docut = truncate!(d; kwargs...)
truncerr, docut = truncate!!(d; kwargs...)
for n in 1:nnzblocks(T)
blockdim = _truncated_blockdim(
Ss[n], docut; min_blockdim, singular_values=true, truncate
Expand Down Expand Up @@ -237,7 +237,7 @@ function LinearAlgebra.eigen(
sort!(d; rev=true, by=abs)

if truncate
truncerr, docut = truncate!(d; kwargs...)
truncerr, docut = truncate!!(d; kwargs...)
for n in 1:nnzblocks(T)
blockdim = _truncated_blockdim(Ds[n], docut)
if blockdim == 0
Expand Down
11 changes: 10 additions & 1 deletion NDTensors/src/dense/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,16 @@ Dense(::Type{ElT}) where {ElT} = Dense{ElT}()
setdata(D::Dense, ndata) = Dense(ndata)
setdata(storagetype::Type{<:Dense}, data) = Dense(data)

copy(D::Dense) = Dense(copy(data(D)))
## There is an GPU arrays which are ReshapedArray can
## fail when trying to copy (fail meaning they call get_index which is slow) so this forces a fix.
## TODO make a better implementation
function copy(D::Dense)
d = data(D)
if d isa Base.ReshapedArray
return Dense(copy(parent(d)))
end
return Dense(copy(data(D)))
end

function Base.real(T::Type{<:Dense})
return set_datatype(T, similartype(datatype(T), real(eltype(T))))
Expand Down
13 changes: 8 additions & 5 deletions NDTensors/src/linearalgebra/svd.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@

## TODO here it looks at the elements of S so convert to CPU when on GPU
## Could write this as a GPU impl which just converts S to array. S
## is not used again so we don't need to convert back to GPU.
Comment on lines +2 to +4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same suggestion here as above for truncate!!, I think we should use leaf_parenttype dispatch.

function checkSVDDone(S::AbstractArray, thresh::Float64)
N = length(S)
(N <= 1 || thresh < 0.0) && return (true, 1)
S1t = S[1] * thresh
Scpu = NDTensors.cpu(S)
S1t = Scpu[1] * thresh
start = 2
while start <= N
(S[start] < S1t) && break
(Scpu[start] < S1t) && break
start += 1
end
if start >= N
Expand All @@ -32,9 +36,8 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int=
V = M' * U

V, R = qr_positive(V)
for n in 1:Nd
D[n] = R[n, n]
end
n = size(R)[1]
D = diag(R, (n - Nd))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is doing what you think it is doing, diag(M, k) selects the kth off-diagonal of the matrix.

I went with:

  D[1:Nd] = diag(R)[1:Nd]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, thanks!


(done, start) = checkSVDDone(D, thresh)

Expand Down
21 changes: 20 additions & 1 deletion NDTensors/src/truncate.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
export truncate!
export truncate!!, truncate!

## TODO Here truncate does logical operations of the values in P
## So its more efficient to just make it a CPU vector and
## convert back to GPU

function truncate!!(P::AbstractVector; kwargs...)
return truncate!(leaf_parenttype(P), P; kwargs...)
end

function truncate!(::Type{<:Array}, P; kwargs...)
return truncate!(P; kwargs...)
end

function truncate!(::Type{<:AbstractArray}, P; kwargs...)
P_cpu = cpu(P)
values = truncate!(P_cpu; kwargs...)
P = adapt(leaf_parenttype(P), P_cpu)
return values
end

function truncate!(P::AbstractVector{ElT}; kwargs...)::Tuple{ElT,ElT} where {ElT}
cutoff::Union{Nothing,ElT} = get(kwargs, :cutoff, zero(ElT))
Expand Down
Loading