Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Extend Tenet.contract to contract :between two Sites (#30)
Browse files Browse the repository at this point in the history
* Impement contract for singular values

* Add tests for extended contract function

* Refactor canonize! function tests with new contract function

* Add additional test

* Apply @mofeing suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

* Use ValSplit in extended contract function

* Fix format

* Fix code

* Add contract! wrapper

* Add test_throws for incorrect kwargs in contract!

* Remove unwanted prints

* remove ValSplit

* Change order of lambda retrival

* Apply suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

* Fix tests

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
jofrevalles and mofeing authored Mar 14, 2024
1 parent c2c8e22 commit b505360
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
22 changes: 22 additions & 0 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,28 @@ function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open},
Chain(Operator(), Open(), arrays)
end

Tenet.contract(tn::Chain, query::Symbol, args...; kwargs...) = contract!(copy(tn), Val(query), args...; kwargs...)
Tenet.contract!(tn::Chain, query::Symbol, args...; kwargs...) = contract!(tn, Val(query), args...; kwargs...)

function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left)
Λᵢ = select(tn, :between, site1, site2)
Λᵢ === nothing && return tn

if direction === :right
Γᵢ₊₁ = select(tn, :tensor, site2)
replace!(TensorNetwork(tn), Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ, dims = ()))
elseif direction === :left
Γᵢ = select(tn, :tensor, site1)
replace!(TensorNetwork(tn), Γᵢ => contract(Λᵢ, Γᵢ, dims = ()))
else
throw(ArgumentError("Unknown direction=:$direction"))
end

delete!(TensorNetwork(tn), Λᵢ)

return tn
end

canonize_site(tn::Chain, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)
canonize_site!(tn::Chain, args...; kwargs...) = canonize_site!(boundary(tn), tn, args...; kwargs...)

Expand Down
51 changes: 42 additions & 9 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,33 @@
@testset "Canonization" begin
using Tenet

@testset "contract" begin
qtn = rand(Chain, Open, State; n = 5, p = 2, χ = 20)
let canonized = canonize(qtn)
@test_throws ArgumentError contract!(canonized, :between, Site(1), Site(2); direction = :dummy)
end

canonized = canonize(qtn)

for i in 1:4
contract_some = contract(canonized, :between, Site(i), Site(i + 1))
Bᵢ = select(contract_some, :tensor, Site(i))

@test isapprox(contract(TensorNetwork(contract_some)), contract(TensorNetwork(qtn)))
@test_throws ArgumentError select(contract_some, :between, Site(i), Site(i + 1))

@test isrightcanonical(contract_some, Site(i))
@test isleftcanonical(
contract(canonized, :between, Site(i), Site(i + 1); direction = :right),
Site(i + 1),
)

Γᵢ = select(canonized, :tensor, Site(i))
Λᵢ₊₁ = select(canonized, :between, Site(i), Site(i + 1))
@test Bᵢ contract(Γᵢ, Λᵢ₊₁; dims = ())
end
end

@testset "canonize_site" begin
qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)])

Expand Down Expand Up @@ -164,28 +191,34 @@
Λ = [select(canonized, :between, Site(i), Site(i + 1)) for i in 1:4]
@test map-> sum(abs2, λ), Λ) ones(length(Λ)) * norm(canonized)^2

for i in 1:4
for i in 1:5
canonized = canonize(qtn)

if i == 1
@test isleftcanonical(canonized, Site(i))
elseif i == 5 # in the limits of the chain, we get the norm of the state
contract!(canonized, :between, Site(i - 1), Site(i); direction = :right)
tensor = select(canonized, :tensor, Site(i))
replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized))
@test isleftcanonical(canonized, Site(i))
else
Γᵢ = select(canonized, :tensor, Site(i))
Λᵢ = pop!(TensorNetwork(canonized), select(canonized, :between, Site(i - 1), Site(i)))
replace!(TensorNetwork(canonized), Γᵢ => contract(Λᵢ, Γᵢ; dims = ()))
contract!(canonized, :between, Site(i - 1), Site(i); direction = :right)
@test isleftcanonical(canonized, Site(i))
end
end

for i in 2:5
for i in 1:5
canonized = canonize(qtn)

if i == 5
if i == 1 # in the limits of the chain, we get the norm of the state
contract!(canonized, :between, Site(i), Site(i + 1); direction = :left)
tensor = select(canonized, :tensor, Site(i))
replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized))
@test isrightcanonical(canonized, Site(i))
elseif i == 5
@test isrightcanonical(canonized, Site(i))
else
Γᵢ = select(canonized, :tensor, Site(i))
Λᵢ₊₁ = pop!(TensorNetwork(canonized), select(canonized, :between, Site(i), Site(i + 1)))
replace!(TensorNetwork(canonized), Γᵢ => contract(Γᵢ, Λᵢ₊₁; dims = ()))
contract!(canonized, :between, Site(i), Site(i + 1); direction = :left)
@test isrightcanonical(canonized, Site(i))
end
end
Expand Down

0 comments on commit b505360

Please sign in to comment.