diff --git a/src/abstractmps.jl b/src/abstractmps.jl index 4670be5..1580efe 100644 --- a/src/abstractmps.jl +++ b/src/abstractmps.jl @@ -1672,13 +1672,18 @@ provided as keyword arguments. Keyword arguments: * `site_range`=1:N - only truncate the MPS bonds between these sites +* `truncation_error` - if provided, will store the truncation error from all SVDs performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_error = Ref{Float64}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_error[] = 0.0`). """ function truncate!(M::AbstractMPS; alg="frobenius", kwargs...) return truncate!(Algorithm(alg), M; kwargs...) end function truncate!( - ::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), kwargs... + ::Algorithm"frobenius", + M::AbstractMPS; + site_range=1:length(M), + truncation_error=nothing, + kwargs..., ) N = length(M) @@ -1690,7 +1695,10 @@ function truncate!( for j in reverse((first(site_range) + 1):last(site_range)) rinds = uniqueinds(M[j], M[j - 1]) ltags = tags(commonind(M[j], M[j - 1])) - U, S, V = svd(M[j], rinds; lefttags=ltags, kwargs...) + U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...) + if !isnothing(truncation_error) + truncation_error[] += spec.truncerr + end M[j] = U M[j - 1] *= (S * V) setrightlim!(M, j) diff --git a/test/base/test_mps.jl b/test/base/test_mps.jl index 65eed0e..ff75ef1 100644 --- a/test/base/test_mps.jl +++ b/test/base/test_mps.jl @@ -755,6 +755,16 @@ end truncate!(M; site_range=3:7, maxdim=2) @test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2] end + + @testset "truncate! with truncation_error" begin + M = basicRandomMPS(10; dim=10) + truncation_error = Ref{Float64}() + truncation_error[] = 0.0 + truncate!(M, maxdim=3, cutoff=1E-3, truncation_error=truncation_error) + @test truncation_error[] > 0.0 + end + + end @testset "Other MPS methods" begin