From 7b4d4ef7515cd16005fdfde6568b481c11e505ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Fri, 16 Feb 2024 11:12:08 +0100 Subject: [PATCH] Fix `canonize` function and enhance index semantics (#8) * Fix rightindex and leftindex logic * Fix canonize! function and enhance syntax * Fix typo and add tests * Enhance tests * Update code for Periodic boundary * Fix leftsite and rightsite function * Add Site testset * Update syntax --- src/Ansatz/Chain.jl | 55 +++++++++++++++++++----------- src/Qrochet.jl | 2 +- test/Ansatz/Chain_test.jl | 72 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 20 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 8f88128..6c3beec 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -112,57 +112,74 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end +leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) +function leftsite(::Open, tn::Chain, site::Site) + site.id ∉ range(2, length(sites(tn))) && throw(ArgumentError("Invalid site $site")) + Site(site.id - 1) +end +leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id - 1, length(sites(tn)))) + +rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) +function rightsite(::Open, tn::Chain, site::Site) + site.id ∉ range(1, length(sites(tn))-1) && throw(ArgumentError("Invalid site $site")) + Site(site.id + 1) +end +rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id + 1, length(sites(tn)))) + leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) -leftindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end-1] -function leftindex(::Open, tn::Chain, site::Site) +function leftindex(::Union{Open, Periodic}, tn::Chain, site::Site) if site == site"1" nothing - elseif site == Site(nsites(tn)) # TODO review - (select(tn, :tensor, site)|>inds)[end] else - (select(tn, :tensor, site)|>inds)[end-1] + (select(tn, :tensor, site)|>inds) ∩ (select(tn, :tensor, leftsite(tn, site))|>inds) |> only end end rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) -rightindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end] -function rightindex(::Open, tn::Chain, site::Site) +function rightindex(::Union{Open, Periodic}, tn::Chain, site::Site) if site == Site(nsites(tn)) # TODO review nothing else - (select(tn, :tensor, site)|>inds)[end] + (select(tn, :tensor, site)|>inds) ∩ (select(tn, :tensor, rightsite(tn, site))|>inds) |> only end end +canonize(tn::Chain, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...) canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...) -# NOTE spectral weights are stored in a vector connected to the now virtual hyperindex! -function canonize!(::Open, tn::Chain, site::Site; direction::Symbol) +# NOTE: in mode == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! +function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr) left_inds = Symbol[] right_inds = Symbol[] virtualind = if direction === :left - site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor")) - push!(left_inds, leftindex(tn, site)) - - site == Site(nsites(tn)) || push!(right_inds, rightindex(tn, site)) - push!(right_inds, Quantum(tn)[site]) - - only(left_inds) - elseif direction === :right site == Site(nsites(tn)) && throw(ArgumentError("Cannot right-canonize right-most tensor")) push!(right_inds, rightindex(tn, site)) site == Site(1) || push!(left_inds, leftindex(tn, site)) push!(left_inds, Quantum(tn)[site]) + only(right_inds) + elseif direction === :right + site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor")) + push!(right_inds, leftindex(tn, site)) + + site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site)) + push!(left_inds, Quantum(tn)[site]) + only(right_inds) else throw(ArgumentError("Unknown direction=:$direction")) end tmpind = gensym(:tmp) - qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind) + if mode == :qr + qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind) + elseif mode == :svd + svd!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind) + else + throw(ArgumentError("Unknown mode=:$mode")) + end contract!(TensorNetwork(tn), virtualind) replace!(TensorNetwork(tn), tmpind => virtualind) diff --git a/src/Qrochet.jl b/src/Qrochet.jl index d0387a0..72f4eaa 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -17,7 +17,7 @@ export Product include("Ansatz/Chain.jl") export Chain export MPS, pMPS, MPO, pMPO -export leftindex, rightindex, canonize! +export leftindex, rightindex, canonize, canonize! # reexports from Tenet using Tenet diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 2112dd6..af94f79 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -26,4 +26,76 @@ @test noutputs(qtn) == 3 @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) @test boundary(qtn) == Open() + + @testset "Site" begin + using Qrochet: leftsite, rightsite + qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) + + @test leftsite(qtn, Site(1)) == Site(3) + @test leftsite(qtn, Site(2)) == Site(1) + @test leftsite(qtn, Site(3)) == Site(2) + + @test rightsite(qtn, Site(1)) == Site(2) + @test rightsite(qtn, Site(2)) == Site(3) + @test rightsite(qtn, Site(3)) == Site(1) + + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test_throws ArgumentError leftsite(qtn, Site(1)) + @test_throws ArgumentError rightsite(qtn, Site(3)) + + @test leftsite(qtn, Site(2)) == Site(1) + @test leftsite(qtn, Site(3)) == Site(2) + + @test rightsite(qtn, Site(2)) == Site(3) + @test rightsite(qtn, Site(1)) == Site(2) + end + + @testset "canonize" begin + using Tenet + + function is_left_canonical(qtn, s::Site) + label_r = rightindex(qtn, s) + A = select(qtn, :tensor, s) + try + contracted = contract(A, replace(conj(A), label_r => :new_ind_name)) + return isapprox(contracted, Matrix{Float64}(I, size(A, label_r), size(A, label_r)), atol=1e-12) + catch + return false + end + end + + function is_right_canonical(qtn, s::Site) + label_l = leftindex(qtn, s) + A = select(qtn, :tensor, s) + try + contracted = contract(A, replace(conj(A), label_l => :new_ind_name)) + return isapprox(contracted, Matrix{Float64}(I, size(A, label_l), size(A, label_l)), atol=1e-12) + catch + return false + end + end + + qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)]) + + @test_throws ArgumentError canonize!(qtn, Site(1); direction=:right) + @test_throws ArgumentError canonize!(qtn, Site(3); direction=:left) + + for mode in [:qr, :svd] + for i in 1:length(sites(qtn)) + if i != 1 + canonized = canonize(qtn, Site(i); direction=:right, mode=mode) + @test is_right_canonical(canonized, Site(i)) + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn))) + elseif i != length(sites(qtn)) + canonized = canonize(qtn, Site(i); direction=:left, mode=mode) + @test is_left_canonical(canonized, Site(i)) + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn))) + end + end + end + + # Ensure that svd creates a new tensor + @test length(tensors(canonize(qtn, Site(2); direction=:right, mode=:svd))) == 4 + end end