Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Extend Tenet.contract to contract :between two Sites #30

Merged
merged 15 commits into from
Mar 14, 2024
Merged
22 changes: 22 additions & 0 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,28 @@ function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open},
Chain(Operator(), Open(), arrays)
end

Tenet.contract(tn::Chain, query::Symbol, args...; kwargs...) = contract!(copy(tn), Val(query), args...; kwargs...)
Tenet.contract!(tn::Chain, query::Symbol, args...; kwargs...) = contract!(tn, Val(query), args...; kwargs...)

function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left)
Λᵢ = select(tn, :between, site1, site2)
Λᵢ === nothing && return tn

if direction === :right
Γᵢ₊₁ = select(tn, :tensor, site2)
replace!(TensorNetwork(tn), Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ, dims = ()))
elseif direction === :left
Γᵢ = select(tn, :tensor, site1)
replace!(TensorNetwork(tn), Γᵢ => contract(Λᵢ, Γᵢ, dims = ()))
else
throw(ArgumentError("Unknown direction=:$direction"))
end

delete!(TensorNetwork(tn), Λᵢ)

return tn
end

canonize_site(tn::Chain, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)
canonize_site!(tn::Chain, args...; kwargs...) = canonize_site!(boundary(tn), tn, args...; kwargs...)

Expand Down
51 changes: 42 additions & 9 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,33 @@
@testset "Canonization" begin
using Tenet

@testset "contract" begin
qtn = rand(Chain, Open, State; n = 5, p = 2, χ = 20)
let canonized = canonize(qtn)
@test_throws ArgumentError contract!(canonized, :between, Site(1), Site(2); direction = :dummy)
end

canonized = canonize(qtn)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved

for i in 1:4
contract_some = contract(canonized, :between, Site(i), Site(i + 1))
Bᵢ = select(contract_some, :tensor, Site(i))

@test isapprox(contract(TensorNetwork(contract_some)), contract(TensorNetwork(qtn)))
@test_throws ArgumentError select(contract_some, :between, Site(i), Site(i + 1))

@test isrightcanonical(contract_some, Site(i))
@test isleftcanonical(
contract(canonized, :between, Site(i), Site(i + 1); direction = :right),
Site(i + 1),
)

Γᵢ = select(canonized, :tensor, Site(i))
Λᵢ₊₁ = select(canonized, :between, Site(i), Site(i + 1))
@test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims = ())
end
end

@testset "canonize_site" begin
qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)])

Expand Down Expand Up @@ -164,28 +191,34 @@
Λ = [select(canonized, :between, Site(i), Site(i + 1)) for i in 1:4]
@test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2

for i in 1:4
for i in 1:5
canonized = canonize(qtn)

if i == 1
@test isleftcanonical(canonized, Site(i))
elseif i == 5 # in the limits of the chain, we get the norm of the state
contract!(canonized, :between, Site(i - 1), Site(i); direction = :right)
tensor = select(canonized, :tensor, Site(i))
replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized))
@test isleftcanonical(canonized, Site(i))
else
Γᵢ = select(canonized, :tensor, Site(i))
Λᵢ = pop!(TensorNetwork(canonized), select(canonized, :between, Site(i - 1), Site(i)))
replace!(TensorNetwork(canonized), Γᵢ => contract(Λᵢ, Γᵢ; dims = ()))
contract!(canonized, :between, Site(i - 1), Site(i); direction = :right)
@test isleftcanonical(canonized, Site(i))
end
end

for i in 2:5
for i in 1:5
canonized = canonize(qtn)

if i == 5
if i == 1 # in the limits of the chain, we get the norm of the state
contract!(canonized, :between, Site(i), Site(i + 1); direction = :left)
tensor = select(canonized, :tensor, Site(i))
replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized))
@test isrightcanonical(canonized, Site(i))
elseif i == 5
@test isrightcanonical(canonized, Site(i))
else
Γᵢ = select(canonized, :tensor, Site(i))
Λᵢ₊₁ = pop!(TensorNetwork(canonized), select(canonized, :between, Site(i), Site(i + 1)))
replace!(TensorNetwork(canonized), Γᵢ => contract(Γᵢ, Λᵢ₊₁; dims = ()))
contract!(canonized, :between, Site(i), Site(i + 1); direction = :left)
@test isrightcanonical(canonized, Site(i))
end
end
Expand Down
Loading