Skip to content

Commit

Permalink
Fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Nov 26, 2024
1 parent 34e8e5e commit d096722
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
40 changes: 26 additions & 14 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using Graphs:
using ITensors:
ITensors,
ITensor,
@Algorithm_str,
addtags,
combiner,
commoninds,
Expand All @@ -44,7 +45,7 @@ using MacroTools: @capture
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
using NamedGraphs.GraphsExtensions:
, directed_graph, incident_edges, rename_vertices, vertextype
using NDTensors: NDTensors, dim
using NDTensors: NDTensors, dim, Algorithm
using SplitApplyCombine: flatten

abstract type AbstractITensorNetwork{V} <: AbstractDataGraph{V,ITensor,ITensor} end
Expand Down Expand Up @@ -585,17 +586,22 @@ function LinearAlgebra.factorize(tn::AbstractITensorNetwork, edge::Pair; kwargs.
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function orthogonalize_walk(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
return orthogonalize_walk(tn, [edge]; kwargs...)
function gauge_walk(
alg::Algorithm, tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...
)
return gauge_walk(tn, [edge]; kwargs...)
end

function orthogonalize_walk(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
return orthogonalize_walk(tn, edgetype(tn)(edge); kwargs...)
function gauge_walk(alg::Algorithm, tn::AbstractITensorNetwork, edge::Pair; kwargs...)
return gauge_walk(alg::Algorithm, tn, edgetype(tn)(edge); kwargs...)
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function orthogonalize_walk(
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
function gauge_walk(
alg::Algorithm"orthogonalize",
tn::AbstractITensorNetwork,
edges::Vector{<:AbstractEdge};
kwargs...,
)
# tn = factorize(tn, edge; kwargs...)
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
Expand All @@ -612,22 +618,28 @@ function orthogonalize_walk(
return tn
end

function orthogonalize_walk(tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...)
return orthogonalize_walk(tn, edgetype(tn).(edges); kwargs...)
function gauge_walk(
alg::Algorithm, tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...
)
return gauge_walk(alg, tn, edgetype(tn).(edges); kwargs...)
end

# Orthogonalize an ITensorNetwork towards a region, treating
# Gauge a ITensorNetwork towards a region, treating
# the network as a tree spanned by a spanning tree.
function tree_orthogonalize(ψ::AbstractITensorNetwork, region::Vector)
function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region::Vector)
region_center =
length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region)
path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center)
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
return orthogonalize_walk(ψ, path)
return gauge_walk(alg, ψ, path)
end

function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region)
return tree_gauge(alg, ψ, [region])
end

function tree_orthogonalize::AbstractITensorNetwork, region)
return tree_orthogonalize(ψ, [region])
function tree_orthogonalize::AbstractITensorNetwork, region; kwargs...)
return tree_gauge(Algorithm("orthogonalize"), ψ, region; kwargs...)
end

# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
Expand Down
12 changes: 8 additions & 4 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using NamedGraphs.GraphsExtensions:
a_star
using NamedGraphs: namedgraph_a_star, steiner_tree
using IsApprox: IsApprox, Approx
using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev
using ITensors: ITensors, Algorithm, @Algorithm_str, directsum, hasinds, permute, plev
using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize
using TupleTools: TupleTools

Expand All @@ -35,19 +35,23 @@ function set_ortho_region(tn::AbstractTTN, new_region)
return error("Not implemented")
end

function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
function gauge(alg::Algorithm, ttn::AbstractTTN, region::Vector; kwargs...)
issetequal(region, ortho_region(ttn)) && return ttn
st = steiner_tree(ttn, union(region, ortho_region(ttn)))
path = post_order_dfs_edges(st, first(region))
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
if !isempty(path)
ttn = typeof(ttn)(orthogonalize_walk(ITensorNetwork(ttn), path; kwargs...))
ttn = typeof(ttn)(gauge_walk(alg, ITensorNetwork(ttn), path; kwargs...))
end
return set_ortho_region(ttn, region)
end

function gauge(alg::Algorithm, ttn::AbstractTTN, region; kwargs...)
return gauge(alg, ttn, [region]; kwargs...)
end

function ITensorMPS.orthogonalize(ttn::AbstractTTN, region; kwargs...)
return orthogonalize(ttn, [region]; kwargs...)
return gauge(Algorithm("orthogonalize"), ttn, region; kwargs...)
end

function tree_orthogonalize(ttn::AbstractTTN, args...; kwargs...)
Expand Down

0 comments on commit d096722

Please sign in to comment.