From 837642c2f81fb3ae4c196fa0feca8c7b1cfa4ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:15:33 +0200 Subject: [PATCH 1/7] Add `order` keyword argument in `Chain` constructor functions (#47) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add order kwarg in Chain constructor * Add tests * Fix minor typos * Add minor fixes in Chain constructor * Enhance tests * Format code * Apply @mofeing suggestions from code review Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> * Remove unnecessary helper function * Apply suggestions from code review Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> * Minor aesthetic updates in code --------- Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/Ansatz/Chain.jl | 95 +++++++++++++++---- test/Ansatz/Chain_test.jl | 193 ++++++++++++++++++++++++++++++++------ 2 files changed, 241 insertions(+), 47 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index f33e2ce..4bc699b 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -30,14 +30,30 @@ function Chain(tn::TensorNetwork, sites, args...; kwargs...) Chain(Quantum(tn, sites), args...; kwargs...) end -function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}) +defaultorder(::Type{Chain}, ::State) = (:o, :l, :r) +defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r) + +function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State())) @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n] _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]]) + inds = map(order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n+mod1(i, n)] + elseif dir == :l + symbols[n+mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) @@ -45,22 +61,37 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}) +function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State())) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:end-1]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) n = length(arrays) - symbols = [nextindex() for _ in 1:2n-1] + symbols = [nextindex() for _ in 1:2n] _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - Tensor(array, [symbols[1], symbols[1+n]]) + _order = if i == 1 + filter(x -> x != :l, order) elseif i == n - Tensor(array, [symbols[n], symbols[n+mod1(n - 1, n)]]) + filter(x -> x != :r, order) else - Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]]) + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n+mod1(i, n)] + elseif dir == :l + symbols[n+mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) @@ -68,14 +99,29 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 3 dimensions" +function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) + @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n] _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]]) + inds = map(order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i+n] + elseif dir == :l + symbols[2n+mod1(i - 1, n)] + elseif dir == :r + symbols[2n+mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) @@ -84,22 +130,39 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}) +function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:end-1]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n-1] _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - Tensor(array, [symbols[1], symbols[n+1], symbols[1+2n]]) + _order = if i == 1 + filter(x -> x != :l, order) elseif i == n - Tensor(array, [symbols[n], symbols[2n], symbols[2n+mod1(n - 1, n)]]) + filter(x -> x != :r, order) else - Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]]) + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i+n] + elseif dir == :l + symbols[2n+mod1(i - 1, n)] + elseif dir == :r + symbols[2n+mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 0eb7bcb..7496d85 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -1,35 +1,166 @@ @testset "Chain ansatz" begin - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - @test socket(qtn) == State() - @test ninputs(qtn) == 0 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - @test socket(qtn) == State() - @test ninputs(qtn) == 0 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) - @test socket(qtn) == Operator() - @test ninputs(qtn) == 3 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) - @test socket(qtn) == Operator() - @test ninputs(qtn) == 3 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + @testset "Periodic boundary" begin + @testset "State" begin + qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)] + qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at = Site(1))) == (2, 1, 4) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 3, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) + qtn = Chain(State(), Periodic(), arrays, order = [:r, :o, :l]) + + @test size(tensors(qtn; at = Site(1))) == (4, 2, 1) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 4) + @test size(tensors(qtn; at = Site(3))) == (1, 2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + end + end + + @testset "Operator" begin + qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r) + qtn = Chain(Operator(), Periodic(), arrays) + + @test size(tensors(qtn; at = Site(1))) == (2, 4, 1, 3) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 3, 6) + @test size(tensors(qtn; at = Site(3))) == (2, 4, 6, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + + arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Periodic(), arrays, order = [:r, :o, :l, :i]) + + @test size(tensors(qtn; at = Site(1))) == (3, 2, 1, 4) + @test size(tensors(qtn; at = Site(2))) == (6, 2, 3, 4) + @test size(tensors(qtn; at = Site(3))) == (1, 2, 6, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + end + end + + @testset "Open boundary" begin + @testset "State" begin + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] + qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at = Site(1))) == (2, 1) + @test size(tensors(qtn; at = Site(2))) == (2, 1, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) + qtn = Chain(State(), Open(), arrays, order = [:r, :o, :l]) + + @test size(tensors(qtn; at = Site(1))) == (1, 2) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 1) + @test size(tensors(qtn; at = Site(3))) == (2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + end + end + @testset "Operator" begin + qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) + qtn = Chain(Operator(), Open(), arrays) + + @test size(tensors(qtn; at = Site(1))) == (2, 4, 1) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 1, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 4, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + + arrays = [ + permutedims(arrays[1], (3, 1, 2)), + permutedims(arrays[2], (4, 1, 3, 2)), + permutedims(arrays[3], (1, 3, 2)), + ] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Open(), arrays, order = [:r, :o, :l, :i]) + + @test size(tensors(qtn; at = Site(1))) == (1, 2, 4) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 1, 4) + @test size(tensors(qtn; at = Site(3))) == (2, 3, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + end + end @testset "Site" begin using Qrochet: leftsite, rightsite From fa0d5616f94167d60b5a13df8644e7f1cf633149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Thu, 20 Jun 2024 11:31:01 +0200 Subject: [PATCH 2/7] Fix `mixed_canonize` function (#49) * Add missing svd step in mixed_canonize! function * Update and enhance tests for mixed_canonize function * update docstring * Fix normalize! function * Format code --- src/Ansatz/Chain.jl | 7 +++++-- test/Ansatz/Chain_test.jl | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 4bc699b..abb88aa 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -500,7 +500,7 @@ mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), t mixed_canonize!(boundary::Boundary, tn::Chain, center::Site) Transform a `Chain` tensor network into the mixed-canonical form, that is, -for i < center the tensors are left-canonical and for i > center the tensors are right-canonical, +for i < center the tensors are left-canonical and for i >= center the tensors are right-canonical, and in the center there is a matrix with singular values. """ function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites @@ -514,6 +514,9 @@ function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could b canonize_site!(tn, Site(i); direction = :left, method = :qr) end + # center SVD sweep to get singular values + canonize_site!(tn, center; direction = :left, method = :svd) + return tn end @@ -525,7 +528,7 @@ to mixed-canonized form with the given center site. """ function LinearAlgebra.normalize!(tn::Chain, root::Site; p::Real = 2) mixed_canonize!(tn, root) - normalize!(tensors(Quantum(tn); at = root), p) + normalize!(tensors(tn; between = (Site(id(root) - 1), root)), p) return tn end diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 7496d85..d332397 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -359,9 +359,11 @@ qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) canonized = mixed_canonize(qtn, Site(3)) + @test length(tensors(canonized)) == length(tensors(qtn)) + 1 + @test isleftcanonical(canonized, Site(1)) @test isleftcanonical(canonized, Site(2)) - @test !isleftcanonical(canonized, Site(3)) && !isrightcanonical(canonized, Site(3)) + @test isrightcanonical(canonized, Site(3)) @test isrightcanonical(canonized, Site(4)) @test isrightcanonical(canonized, Site(5)) From b86c970dca698116db65788d238bab2922092c8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Thu, 20 Jun 2024 13:32:12 +0200 Subject: [PATCH 3/7] Fix `leftsite` and `rightsite` for `adjoint` `Chain`s (#50) * Fix rightsite and leftsite for adjoint Chains * Add tests for adjoint in Chain --- src/Ansatz/Chain.jl | 13 ++++++++----- test/Ansatz/Chain_test.jl | 13 +++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index abb88aa..bb0ff55 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -183,19 +183,22 @@ function Base.convert(::Type{Chain}, qtn::Product) end leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) -leftsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1) : nothing -leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn))) +leftsite(::Open, tn::Chain, site::Site) = + id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1; dual = isdual(site)) : nothing +leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)); dual = isdual(site)) rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) -rightsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1) : nothing -rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn))) +rightsite(::Open, tn::Chain, site::Site) = + id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1; dual = isdual(site)) : nothing +rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)); dual = isdual(site)) leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, leftsite(tn, site))) rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) -rightindex(::Open, tn::Chain, site::Site) = site == Site(nlanes(tn)) ? nothing : rightindex(Periodic(), tn, site) +rightindex(::Open, tn::Chain, site::Site) = + site == Site(nlanes(tn); dual = isdual(site)) ? nothing : rightindex(Periodic(), tn, site) rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, rightsite(tn, site))) Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain)) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index d332397..1b7ab95 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -380,5 +380,18 @@ isapprox(norm(qtn), 1.0) end + @testset "adjoint" begin + qtn = rand(Chain, Open, State; n = 5, p = 2, χ = 10) + adjoint_qtn = adjoint(qtn) + + for i in 1:nsites(qtn) + i < nsites(qtn) && + @test rightindex(adjoint_qtn, Site(i; dual = true)) == Symbol(String(rightindex(qtn, Site(i))) * "'") + i > 1 && @test leftindex(adjoint_qtn, Site(i; dual = true)) == Symbol(String(leftindex(qtn, Site(i))) * "'") + end + + @test isapprox(contract(TensorNetwork(qtn)), contract(TensorNetwork(adjoint_qtn))) + end + # TODO test `evolve!` methods end From 9b618d4501427abfd2b3308febbe3aede6a8d51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:32:50 +0200 Subject: [PATCH 4/7] Update Muscle to v0.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e7edab0..d593833 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ QrochetYaoExt = "Yao" [compat] ChainRulesCore = "1.0" ChainRulesTestUtils = "1" -Muscle = "0.1" +Muscle = "0.2" Quac = "0.3" Tenet = "0.6" Yao = "0.8, 0.9" From 266ae528e6f9084c2a41720a6fa8e9f96b1d7f96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 25 Jun 2024 18:41:38 +0200 Subject: [PATCH 5/7] Fix Adapt integration --- Project.toml | 2 ++ ext/QrochetAdaptExt.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index d593833..700c249 100644 --- a/Project.toml +++ b/Project.toml @@ -10,12 +10,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c" [extensions] +QrochetAdaptExt = "Adapt" QrochetChainRulesCoreExt = "ChainRulesCore" QrochetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"] QrochetQuacExt = "Quac" diff --git a/ext/QrochetAdaptExt.jl b/ext/QrochetAdaptExt.jl index 19b4114..88c6f1a 100644 --- a/ext/QrochetAdaptExt.jl +++ b/ext/QrochetAdaptExt.jl @@ -5,5 +5,7 @@ using Tenet using Adapt Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites) +Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x))) +Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x)) end From 369dcc82f6be822f4cf08d1fbee6aaf80b31bbac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 25 Jun 2024 18:42:01 +0200 Subject: [PATCH 6/7] Refactor `Product`-`Chain` contraction with `overlap` --- src/Ansatz/Chain.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index bb0ff55..e303ea7 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -748,5 +748,5 @@ function overlap(::State, a::Chain, ::State, b::Chain) end # TODO optimize -overlap(a::Product, b::Chain) = overlap(convert(Chain, a), b) -overlap(a::Chain, b::Product) = overlap(a, convert(Chain, b)) +overlap(a::Product, b::Chain) = contract(TensorNetwork(merge(Quantum(a), Quantum(b)'))) +overlap(a::Chain, b::Product) = contract(TensorNetwork(merge(Quantum(a), Quantum(b)'))) From 0e673be91fc581e646c683b5c952e9556980f649 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Wed, 26 Jun 2024 00:13:35 +0000 Subject: [PATCH 7/7] CompatHelper: add new compat entry for Adapt in [weakdeps] at version 4, (keep existing compat) --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 700c249..93ce92d 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ QrochetQuacExt = "Quac" QrochetYaoExt = "Yao" [compat] +Adapt = "4" ChainRulesCore = "1.0" ChainRulesTestUtils = "1" Muscle = "0.2"