Skip to content

Commit

Permalink
Improve truncate and CheckSVDDone
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 23, 2023
1 parent 28b9e5a commit f56a3b3
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 21 deletions.
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}

dropblocks = Int[]
if truncate
truncerr, docut = truncate!(d; kwargs...)
d, truncerr, docut = truncate!!(d; kwargs...)
for n in 1:nnzblocks(T)
blockdim = _truncated_blockdim(
Ss[n], docut; min_blockdim, singular_values=true, truncate
Expand Down Expand Up @@ -237,7 +237,7 @@ function eigen(
sort!(d; rev=true, by=abs)

if truncate
truncerr, docut = truncate!(d; kwargs...)
d, truncerr, docut = truncate!!(d; kwargs...)
for n in 1:nnzblocks(T)
blockdim = _truncated_blockdim(Ds[n], docut)
if blockdim == 0
Expand Down
18 changes: 6 additions & 12 deletions NDTensors/src/linearalgebra/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,9 @@ function svd(T::DenseTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT}

P = MS .^ 2
if truncate
P_cpu = NDTensors.cpu(P)
truncerr, _ = truncate!(
P_cpu; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
P, truncerr, _ = truncate!!(
P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
)
P = adapt(typeof(P), P_cpu)
else
truncerr = 0.0
end
Expand Down Expand Up @@ -240,11 +238,9 @@ function eigen(
VM = VM[:, p]

if truncate
DM_cpu = NDTensors.cpu(DM)
truncerr, _ = truncate!(
DM_cpu; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
DM, truncerr, _ = truncate!!(
DM; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
)
DM = adapt(typeof(DM), DM_cpu)
dD = length(DM)
if dD < size(VM, 2)
VM = VM[:, 1:dD]
Expand Down Expand Up @@ -359,11 +355,9 @@ function eigen(
#VM = VM[:,p]

if truncate
DM_cpu = NDTensors.cpu(DM)
truncerr, _ = truncate!(
DM_cpu; maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
DM, truncerr, _ = truncate!!(
DM; maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
)
DM = adapt(typeof(DM), DM_cpu)
dD = length(DM)
if dD < size(VM, 2)
VM = VM[:, 1:dD]
Expand Down
15 changes: 11 additions & 4 deletions NDTensors/src/linearalgebra/svd.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@

function checkSVDDone(S::AbstractArray, thresh::Float64)
# Convert to CPU to avoid slow scalar indexing
# on GPU.
S = cpu(S)
return checkSVDDone(leaf_parenttype(S), S, thresh)
end

# CPU version.
function checkSVDDone(::Type{<:Array}, S::AbstractArray, thresh::Float64)
N = length(S)
(N <= 1 || thresh < 0.0) && return (true, 1)
S1t = S[1] * thresh
Expand All @@ -17,6 +18,12 @@ function checkSVDDone(S::AbstractArray, thresh::Float64)
return (false, start)
end

# Convert to CPU to avoid slow scalar indexing
# on GPU.
function checkSVDDone(::Type{<:AbstractArray}, S::AbstractArray, thresh::Float64)
return checkSVDDone(Array, cpu(S), thresh)
end

function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int=2)
Mr, Mc = size(M)
if Mr > Mc
Expand Down
24 changes: 21 additions & 3 deletions NDTensors/src/truncate.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
export truncate!

function truncate!(P::AbstractVector{ElT}; kwargs...)::Tuple{ElT,ElT} where {ElT}
cutoff::Union{Nothing,ElT} = get(kwargs, :cutoff, zero(ElT))
function truncate!!(P::AbstractArray; kwargs...)
return truncate!!(leaf_parenttype(P), P; kwargs...)
end

# CPU version.
function truncate!!(::Type{<:Array}, P::AbstractArray; kwargs...)
P, truncerr, docut = truncate!(P; kwargs...)
return P, truncerr, docut
end

# GPU fallback version, convert to CPU.
function truncate!!(::Type{<:AbstractArray}, P::AbstractArray; kwargs...)
P_cpu = cpu(P)
P_cpu, truncerr, docut = truncate!(P_cpu; kwargs...)
P = adapt(leaf_parenttype(P), P_cpu)
return P, truncerr, docut
end

# CPU implementation.
function truncate!(P::AbstractVector{ElT}; cutoff=zero(eltype(P)), kwargs...) where {ElT}
if isnothing(cutoff)
cutoff = typemin(ElT)
end
Expand Down Expand Up @@ -92,5 +110,5 @@ function truncate!(P::AbstractVector{ElT}; kwargs...)::Tuple{ElT,ElT} where {ElT
s < 0 && (P .*= s)
resize!(P, n)

return truncerr, docut
return P, truncerr, docut
end

0 comments on commit f56a3b3

Please sign in to comment.