Skip to content

Commit

Permalink
implement truncation_error keyword arg for truncate!
Browse files Browse the repository at this point in the history
  • Loading branch information
NuclearPowerNerd committed Nov 20, 2024
1 parent d7f8ca3 commit 5c887ce
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1672,13 +1672,18 @@ provided as keyword arguments.
Keyword arguments:
* `site_range`=1:N - only truncate the MPS bonds between these sites
* `truncation_error` - if provided, will store the truncation error from all SVDs performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_error = Ref{Float64}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_error[] = 0.0`).
"""
function truncate!(M::AbstractMPS; alg="frobenius", kwargs...)
return truncate!(Algorithm(alg), M; kwargs...)
end

function truncate!(
::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), kwargs...
::Algorithm"frobenius",
M::AbstractMPS;
site_range=1:length(M),
truncation_error=nothing,
kwargs...,
)
N = length(M)

Expand All @@ -1690,7 +1695,10 @@ function truncate!(
for j in reverse((first(site_range) + 1):last(site_range))
rinds = uniqueinds(M[j], M[j - 1])
ltags = tags(commonind(M[j], M[j - 1]))
U, S, V = svd(M[j], rinds; lefttags=ltags, kwargs...)
U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...)
if !isnothing(truncation_error)
truncation_error[] += spec.truncerr
end
M[j] = U
M[j - 1] *= (S * V)
setrightlim!(M, j)
Expand Down
10 changes: 10 additions & 0 deletions test/base/test_mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,16 @@ end
truncate!(M; site_range=3:7, maxdim=2)
@test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2]
end

@testset "truncate! with truncation_error" begin
M = basicRandomMPS(10; dim=10)
truncation_error = Ref{Float64}()
truncation_error[] = 0.0
truncate!(M, maxdim=3, cutoff=1E-3, truncation_error=truncation_error)
@test truncation_error[] > 0.0
end


end

@testset "Other MPS methods" begin
Expand Down

0 comments on commit 5c887ce

Please sign in to comment.