Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Fix leftsite and rightsite for adjoint Chains (#50)
Browse files Browse the repository at this point in the history
* Fix rightsite and leftsite for adjoint Chains

* Add tests for adjoint in Chain
  • Loading branch information
jofrevalles authored Jun 20, 2024
1 parent fa0d561 commit b86c970
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,22 @@ function Base.convert(::Type{Chain}, qtn::Product)
end

leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site)
leftsite(::Open, tn::Chain, site::Site) = id(site) range(2, nlanes(tn)) ? Site(id(site) - 1) : nothing
leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)))
leftsite(::Open, tn::Chain, site::Site) =
id(site) range(2, nlanes(tn)) ? Site(id(site) - 1; dual = isdual(site)) : nothing
leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)); dual = isdual(site))

rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site)
rightsite(::Open, tn::Chain, site::Site) = id(site) range(1, nlanes(tn) - 1) ? Site(id(site) + 1) : nothing
rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)))
rightsite(::Open, tn::Chain, site::Site) =
id(site) range(1, nlanes(tn) - 1) ? Site(id(site) + 1; dual = isdual(site)) : nothing
rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)); dual = isdual(site))

leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site)
leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site)
leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, leftsite(tn, site)))

rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site)
rightindex(::Open, tn::Chain, site::Site) = site == Site(nlanes(tn)) ? nothing : rightindex(Periodic(), tn, site)
rightindex(::Open, tn::Chain, site::Site) =
site == Site(nlanes(tn); dual = isdual(site)) ? nothing : rightindex(Periodic(), tn, site)
rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, rightsite(tn, site)))

Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain))
Expand Down
13 changes: 13 additions & 0 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,5 +380,18 @@
isapprox(norm(qtn), 1.0)
end

@testset "adjoint" begin
qtn = rand(Chain, Open, State; n = 5, p = 2, χ = 10)
adjoint_qtn = adjoint(qtn)

for i in 1:nsites(qtn)
i < nsites(qtn) &&
@test rightindex(adjoint_qtn, Site(i; dual = true)) == Symbol(String(rightindex(qtn, Site(i))) * "'")
i > 1 && @test leftindex(adjoint_qtn, Site(i; dual = true)) == Symbol(String(leftindex(qtn, Site(i))) * "'")
end

@test isapprox(contract(TensorNetwork(qtn)), contract(TensorNetwork(adjoint_qtn)))
end

# TODO test `evolve!` methods
end

0 comments on commit b86c970

Please sign in to comment.