Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Fix canonize function and enhance index semantics #8

Merged
merged 8 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,57 +112,66 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray})
Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site)
rightsite(::Union{Open, Periodic}, tn::Chain, site::Site) = Site(site.id + 1)

leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site)
leftsite(::Union{Open, Periodic}, tn::Chain, site::Site) = Site(site.id - 1)

leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site)
leftindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end-1]
function leftindex(::Open, tn::Chain, site::Site)
function leftindex(::Union{Open, Periodic}, tn::Chain, site::Site)
if site == site"1"
nothing
elseif site == Site(nsites(tn)) # TODO review
(select(tn, :tensor, site)|>inds)[end]
else
(select(tn, :tensor, site)|>inds)[end-1]
(select(tn, :tensor, site)|>inds) ∩ (select(tn, :tensor, leftsite(tn, site))|>inds) |> only
end
end

rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site)
rightindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end]
function rightindex(::Open, tn::Chain, site::Site)
function rightindex(::Union{Open, Periodic}, tn::Chain, site::Site)
if site == Site(nsites(tn)) # TODO review
nothing
else
(select(tn, :tensor, site)|>inds)[end]
(select(tn, :tensor, site)|>inds) ∩ (select(tn, :tensor, rightsite(tn, site))|>inds) |> only
end
end

canonize(tn::Chain, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...)
canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...)

# NOTE spectral weights are stored in a vector connected to the now virtual hyperindex!
function canonize!(::Open, tn::Chain, site::Site; direction::Symbol)
# NOTE: in mode == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex!
function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr)
left_inds = Symbol[]
right_inds = Symbol[]

virtualind = if direction === :left
site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor"))
push!(left_inds, leftindex(tn, site))

site == Site(nsites(tn)) || push!(right_inds, rightindex(tn, site))
push!(right_inds, Quantum(tn)[site])

only(left_inds)
elseif direction === :right
site == Site(nsites(tn)) && throw(ArgumentError("Cannot right-canonize right-most tensor"))
push!(right_inds, rightindex(tn, site))

site == Site(1) || push!(left_inds, leftindex(tn, site))
push!(left_inds, Quantum(tn)[site])

only(right_inds)
elseif direction === :right
site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor"))
push!(right_inds, leftindex(tn, site))

site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site))
push!(left_inds, Quantum(tn)[site])

only(right_inds)
else
throw(ArgumentError("Unknown direction=:$direction"))
end

tmpind = gensym(:tmp)
qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
if mode == :qr
qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
elseif mode == :svd
svd!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
else
throw(ArgumentError("Unknown mode=:$mode"))
end

contract!(TensorNetwork(tn), virtualind)
replace!(TensorNetwork(tn), tmpind => virtualind)
Expand Down
2 changes: 1 addition & 1 deletion src/Qrochet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export Product
include("Ansatz/Chain.jl")
export Chain
export MPS, pMPS, MPO, pMPO
export leftindex, rightindex, canonize!
export leftindex, rightindex, canonize, canonize!

# reexports from Tenet
using Tenet
Expand Down
48 changes: 48 additions & 0 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,52 @@
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"])
@test boundary(qtn) == Open()

@testset "canonize" begin
using Tenet

function is_left_canonical(qtn, s::Site)
label_r = rightindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_r => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_r), size(A, label_r)), atol=1e-12)
catch
return false
end
end

function is_right_canonical(qtn, s::Site)
label_l = leftindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_l => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_l), size(A, label_l)), atol=1e-12)
catch
return false
end
end

qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)])

@test_throws ArgumentError canonize!(qtn, Site(1); direction=:right)
@test_throws ArgumentError canonize!(qtn, Site(3); direction=:left)

for mode in [:qr, :svd]
for i in 1:length(sites(qtn))
if i != 1
canonized = canonize(qtn, Site(i); direction=:right, mode=mode)
@test is_right_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
elseif i != length(sites(qtn))
canonized = canonize(qtn, Site(i); direction=:left, mode=mode)
@test is_left_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
end
end
end

# Ensure that svd creates a new tensor
@test length(tensors(canonize(qtn, Site(2); direction=:right, mode=:svd))) == 4
end
end
Loading