Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Fix canonize function and enhance index semantics (#8)
Browse files Browse the repository at this point in the history
* Fix rightindex and leftindex logic

* Fix canonize! function and enhance syntax

* Fix typo and add tests

* Enhance tests

* Update code for Periodic boundary

* Fix leftsite and rightsite function

* Add Site testset

* Update syntax
  • Loading branch information
jofrevalles authored Feb 16, 2024
1 parent 482ac90 commit 7b4d4ef
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 20 deletions.
55 changes: 36 additions & 19 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,57 +112,74 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray})
Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site)
function leftsite(::Open, tn::Chain, site::Site)
site.id range(2, length(sites(tn))) && throw(ArgumentError("Invalid site $site"))
Site(site.id - 1)
end
leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id - 1, length(sites(tn))))

rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site)
function rightsite(::Open, tn::Chain, site::Site)
site.id range(1, length(sites(tn))-1) && throw(ArgumentError("Invalid site $site"))
Site(site.id + 1)
end
rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id + 1, length(sites(tn))))

leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site)
leftindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end-1]
function leftindex(::Open, tn::Chain, site::Site)
function leftindex(::Union{Open, Periodic}, tn::Chain, site::Site)
if site == site"1"
nothing
elseif site == Site(nsites(tn)) # TODO review
(select(tn, :tensor, site)|>inds)[end]
else
(select(tn, :tensor, site)|>inds)[end-1]
(select(tn, :tensor, site)|>inds) (select(tn, :tensor, leftsite(tn, site))|>inds) |> only
end
end

rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site)
rightindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end]
function rightindex(::Open, tn::Chain, site::Site)
function rightindex(::Union{Open, Periodic}, tn::Chain, site::Site)
if site == Site(nsites(tn)) # TODO review
nothing
else
(select(tn, :tensor, site)|>inds)[end]
(select(tn, :tensor, site)|>inds) (select(tn, :tensor, rightsite(tn, site))|>inds) |> only
end
end

canonize(tn::Chain, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...)
canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...)

# NOTE spectral weights are stored in a vector connected to the now virtual hyperindex!
function canonize!(::Open, tn::Chain, site::Site; direction::Symbol)
# NOTE: in mode == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex!
function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr)
left_inds = Symbol[]
right_inds = Symbol[]

virtualind = if direction === :left
site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor"))
push!(left_inds, leftindex(tn, site))

site == Site(nsites(tn)) || push!(right_inds, rightindex(tn, site))
push!(right_inds, Quantum(tn)[site])

only(left_inds)
elseif direction === :right
site == Site(nsites(tn)) && throw(ArgumentError("Cannot right-canonize right-most tensor"))
push!(right_inds, rightindex(tn, site))

site == Site(1) || push!(left_inds, leftindex(tn, site))
push!(left_inds, Quantum(tn)[site])

only(right_inds)
elseif direction === :right
site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor"))
push!(right_inds, leftindex(tn, site))

site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site))
push!(left_inds, Quantum(tn)[site])

only(right_inds)
else
throw(ArgumentError("Unknown direction=:$direction"))
end

tmpind = gensym(:tmp)
qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
if mode == :qr
qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
elseif mode == :svd
svd!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
else
throw(ArgumentError("Unknown mode=:$mode"))
end

contract!(TensorNetwork(tn), virtualind)
replace!(TensorNetwork(tn), tmpind => virtualind)
Expand Down
2 changes: 1 addition & 1 deletion src/Qrochet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export Product
include("Ansatz/Chain.jl")
export Chain
export MPS, pMPS, MPO, pMPO
export leftindex, rightindex, canonize!
export leftindex, rightindex, canonize, canonize!

# reexports from Tenet
using Tenet
Expand Down
72 changes: 72 additions & 0 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,76 @@
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"])
@test boundary(qtn) == Open()

@testset "Site" begin
using Qrochet: leftsite, rightsite
qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3])

@test leftsite(qtn, Site(1)) == Site(3)
@test leftsite(qtn, Site(2)) == Site(1)
@test leftsite(qtn, Site(3)) == Site(2)

@test rightsite(qtn, Site(1)) == Site(2)
@test rightsite(qtn, Site(2)) == Site(3)
@test rightsite(qtn, Site(3)) == Site(1)

qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)])

@test_throws ArgumentError leftsite(qtn, Site(1))
@test_throws ArgumentError rightsite(qtn, Site(3))

@test leftsite(qtn, Site(2)) == Site(1)
@test leftsite(qtn, Site(3)) == Site(2)

@test rightsite(qtn, Site(2)) == Site(3)
@test rightsite(qtn, Site(1)) == Site(2)
end

@testset "canonize" begin
using Tenet

function is_left_canonical(qtn, s::Site)
label_r = rightindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_r => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_r), size(A, label_r)), atol=1e-12)
catch
return false
end
end

function is_right_canonical(qtn, s::Site)
label_l = leftindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_l => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_l), size(A, label_l)), atol=1e-12)
catch
return false
end
end

qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)])

@test_throws ArgumentError canonize!(qtn, Site(1); direction=:right)
@test_throws ArgumentError canonize!(qtn, Site(3); direction=:left)

for mode in [:qr, :svd]
for i in 1:length(sites(qtn))
if i != 1
canonized = canonize(qtn, Site(i); direction=:right, mode=mode)
@test is_right_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
elseif i != length(sites(qtn))
canonized = canonize(qtn, Site(i); direction=:left, mode=mode)
@test is_left_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
end
end
end

# Ensure that svd creates a new tensor
@test length(tensors(canonize(qtn, Site(2); direction=:right, mode=:svd))) == 4
end
end

0 comments on commit 7b4d4ef

Please sign in to comment.