Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Merge branch 'master' into compathelper/new_version/2024-06-18-00-14-…
Browse files Browse the repository at this point in the history
…16-888-02238818074
  • Loading branch information
mofeing authored Jun 26, 2024
2 parents 25a6701 + 0e673be commit 2b62428
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 58 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,24 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c"

[extensions]
QrochetAdaptExt = "Adapt"
QrochetChainRulesCoreExt = "ChainRulesCore"
QrochetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"]
QrochetQuacExt = "Quac"
QrochetYaoExt = "Yao"

[compat]
Adapt = "4"
ChainRulesCore = "1.0"
ChainRulesTestUtils = "1"
Muscle = "0.1, 0.2"
Muscle = "0.2"
Quac = "0.3"
Tenet = "0.6"
Yao = "0.8, 0.9"
Expand Down
2 changes: 2 additions & 0 deletions ext/QrochetAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@ using Tenet
using Adapt

Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites)
Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x)))
Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x))

end
119 changes: 94 additions & 25 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,52 +30,98 @@ function Chain(tn::TensorNetwork, sites, args...; kwargs...)
Chain(Quantum(tn, sites), args...; kwargs...)
end

function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray})
defaultorder(::Type{Chain}, ::State) = (:o, :l, :r)
defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r)

function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State()))
@assert all(==(3) ndims, arrays) "All arrays must have 3 dimensions"
issetequal(order, defaultorder(Chain, State())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:2n]

_tensors = map(enumerate(arrays)) do (i, array)
Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]])
inds = map(order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n+mod1(i, n)]
elseif dir == :l
symbols[n+mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)

Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray})
function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State()))
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ndims, arrays[2:end-1]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"
issetequal(order, defaultorder(Chain, State())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:2n-1]
symbols = [nextindex() for _ in 1:2n]

_tensors = map(enumerate(arrays)) do (i, array)
if i == 1
Tensor(array, [symbols[1], symbols[1+n]])
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
Tensor(array, [symbols[n], symbols[n+mod1(n - 1, n)]])
filter(x -> x != :r, order)
else
Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]])
order
end

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n+mod1(i, n)]
elseif dir == :l
symbols[n+mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)

Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray})
@assert all(==(4) ndims, arrays) "All arrays must have 3 dimensions"
function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator()))
@assert all(==(4) ndims, arrays) "All arrays must have 4 dimensions"
issetequal(order, defaultorder(Chain, Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:3n]

_tensors = map(enumerate(arrays)) do (i, array)
Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]])
inds = map(order) do dir
if dir == :o
symbols[i]
elseif dir == :i
symbols[i+n]
elseif dir == :l
symbols[2n+mod1(i - 1, n)]
elseif dir == :r
symbols[2n+mod1(i, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
Expand All @@ -84,22 +130,39 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray})
Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray})
function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator()))
@assert ndims(arrays[1]) == 3 "First array must have 3 dimensions"
@assert all(==(4) ndims, arrays[2:end-1]) "All arrays must have 4 dimensions"
@assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions"
issetequal(order, defaultorder(Chain, Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:3n-1]

_tensors = map(enumerate(arrays)) do (i, array)
if i == 1
Tensor(array, [symbols[1], symbols[n+1], symbols[1+2n]])
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
Tensor(array, [symbols[n], symbols[2n], symbols[2n+mod1(n - 1, n)]])
filter(x -> x != :r, order)
else
Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]])
order
end

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :i
symbols[i+n]
elseif dir == :l
symbols[2n+mod1(i - 1, n)]
elseif dir == :r
symbols[2n+mod1(i, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
Expand All @@ -120,19 +183,22 @@ function Base.convert(::Type{Chain}, qtn::Product)
end

leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site)
leftsite(::Open, tn::Chain, site::Site) = id(site) range(2, nlanes(tn)) ? Site(id(site) - 1) : nothing
leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)))
leftsite(::Open, tn::Chain, site::Site) =
id(site) range(2, nlanes(tn)) ? Site(id(site) - 1; dual = isdual(site)) : nothing
leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)); dual = isdual(site))

rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site)
rightsite(::Open, tn::Chain, site::Site) = id(site) range(1, nlanes(tn) - 1) ? Site(id(site) + 1) : nothing
rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)))
rightsite(::Open, tn::Chain, site::Site) =
id(site) range(1, nlanes(tn) - 1) ? Site(id(site) + 1; dual = isdual(site)) : nothing
rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)); dual = isdual(site))

leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site)
leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site)
leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, leftsite(tn, site)))

rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site)
rightindex(::Open, tn::Chain, site::Site) = site == Site(nlanes(tn)) ? nothing : rightindex(Periodic(), tn, site)
rightindex(::Open, tn::Chain, site::Site) =
site == Site(nlanes(tn); dual = isdual(site)) ? nothing : rightindex(Periodic(), tn, site)
rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, rightsite(tn, site)))

Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain))
Expand Down Expand Up @@ -437,7 +503,7 @@ mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), t
mixed_canonize!(boundary::Boundary, tn::Chain, center::Site)
Transform a `Chain` tensor network into the mixed-canonical form, that is,
for i < center the tensors are left-canonical and for i > center the tensors are right-canonical,
for i < center the tensors are left-canonical and for i >= center the tensors are right-canonical,
and in the center there is a matrix with singular values.
"""
function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites
Expand All @@ -451,6 +517,9 @@ function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could b
canonize_site!(tn, Site(i); direction = :left, method = :qr)
end

# center SVD sweep to get singular values
canonize_site!(tn, center; direction = :left, method = :svd)

return tn
end

Expand All @@ -462,7 +531,7 @@ to mixed-canonized form with the given center site.
"""
function LinearAlgebra.normalize!(tn::Chain, root::Site; p::Real = 2)
mixed_canonize!(tn, root)
normalize!(tensors(Quantum(tn); at = root), p)
normalize!(tensors(tn; between = (Site(id(root) - 1), root)), p)
return tn
end

Expand Down Expand Up @@ -679,5 +748,5 @@ function overlap(::State, a::Chain, ::State, b::Chain)
end

# TODO optimize
overlap(a::Product, b::Chain) = overlap(convert(Chain, a), b)
overlap(a::Chain, b::Product) = overlap(a, convert(Chain, b))
overlap(a::Product, b::Chain) = contract(TensorNetwork(merge(Quantum(a), Quantum(b)')))
overlap(a::Chain, b::Product) = contract(TensorNetwork(merge(Quantum(a), Quantum(b)')))
Loading

0 comments on commit 2b62428

Please sign in to comment.