Skip to content

Commit

Permalink
Improve orthogonalize method efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Nov 8, 2024
1 parent f5d3aa4 commit fcf5b98
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
6 changes: 4 additions & 2 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Graphs:
add_edge!,
add_vertex!,
bfs_tree,
center,
dst,
edges,
edgetype,
Expand Down Expand Up @@ -618,8 +619,9 @@ end
# Orthogonalize an ITensorNetwork towards a region, treating
# the network as a tree spanned by a spanning tree.
function tree_orthogonalize::AbstractITensorNetwork, region::Vector)
region = collect(vertices(steiner_tree(underlying_graph(ψ), region)))
path = post_order_dfs_edges(bfs_tree(ψ, first(region)), first(region))
region_centre =
length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region)
path = post_order_dfs_edges(bfs_tree(ψ, region_centre), region_centre)
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
return orthogonalize_path(ψ, path)
end
Expand Down
3 changes: 2 additions & 1 deletion src/solvers/extract/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
# insert_local_tensors takes that tensor and factorizes it back
# apart and puts it back into the network.
#

function default_extracter(state, projected_operator, region; internal_kwargs)
if isa(region, AbstractEdge)
# TODO: add functionality for orthogonalizing onto a bond so that can be called instead
vsrc, vdst = src(region), dst(region)
state = orthogonalize(state, vsrc)
left_inds = uniqueinds(state[vsrc], state[vdst])
#ToDo: replace with call to factorize
U, S, V = svd(
state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
)
Expand Down
5 changes: 5 additions & 0 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ function set_ortho_region(tn::AbstractTTN, new_region)
end

function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
return orthogonalize_ttn(ttn, region; kwargs...)
end

function orthogonalize_ttn(ttn::AbstractTTN, region::Vector; kwargs...)
issetequal(region, ortho_region(ttn)) && return ttn
st = steiner_tree(ttn, union(region, ortho_region(ttn)))
path = post_order_dfs_edges(st, first(region))
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
Expand Down
5 changes: 3 additions & 2 deletions test/test_itensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ using ITensorNetworks:
orthogonalize,
random_tensornetwork,
siteinds,
tree_orthogonalize,
ttn
using LinearAlgebra: factorize
using NamedGraphs: NamedEdge
Expand Down Expand Up @@ -287,13 +288,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test nv(tn_ortho) == 5
@test nv(tn) == 4
@test Z
tn_ortho = orthogonalize(tn, 4 => 3)
tn_ortho = tree_orthogonalize(tn, [3, 4])
= norm_sqr(tn_ortho)
@test nv(tn_ortho) == 4
@test nv(tn) == 4
@test Z

tn_ortho = orthogonalize(tn, 1)
tn_ortho = tree_orthogonalize(tn, 1)
= norm_sqr(tn_ortho)
@test Z
= inner(tn_ortho, tn)
Expand Down

0 comments on commit fcf5b98

Please sign in to comment.