From 677d49b93b1e8240666889d5b5fab24a9a56eee0 Mon Sep 17 00:00:00 2001 From: Joey Date: Wed, 10 Jul 2024 15:58:19 +0200 Subject: [PATCH 1/7] Testing in-place vs out-of-place versions --- src/caches/beliefpropagationcache.jl | 63 +++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 2ce338f3..f3b42b7e 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -178,7 +178,30 @@ end """ Do a sequential update of the message tensors on `edges` """ -function update( +function update!( + bp_cache::BeliefPropagationCache, + edges::Vector{<:PartitionEdge}; + (update_diff!)=nothing, + kwargs..., +) + prev_mts = copy(messages(bp_cache)) + mts = messages(bp_cache) + for e in edges + if !haskey(prev_mts, e) + set!(prev_mts, e, default_message(bp_cache, e)) + end + set!(mts, e, update_message(bp_cache, e; kwargs...)) + if !isnothing(update_diff!) + update_diff![] += message_diff(prev_mts[e], mts[e]) + end + end + return bp_cache +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update_V1( bp_cache::BeliefPropagationCache, edges::Vector{<:PartitionEdge}; (update_diff!)=nothing, @@ -187,7 +210,7 @@ function update( bp_cache_updated = copy(bp_cache) mts = messages(bp_cache_updated) for e in edges - set!(mts, e, update_message(bp_cache_updated, e; kwargs...)) + set!(mts, e, update_message(bp_cache, e; kwargs...)) if !isnothing(update_diff!) update_diff![] += message_diff(message(bp_cache, e), mts[e]) end @@ -195,9 +218,28 @@ function update( return bp_cache_updated end +""" +Out of place version +""" +function update_V2( + bp_cache::BeliefPropagationCache, + edges::Vector{<:PartitionEdge}; + (update_diff!)=nothing, + kwargs..., +) + bp_cache_updated = copy(bp_cache) + bp_cache_updated = update!(bp_cache_updated, edges; (update_diff!), kwargs...) + return bp_cache_updated +end + + """ Update the message tensor on a single edge """ +function update!(bp_cache::BeliefPropagationCache, edge::PartitionEdge; kwargs...) + return update!(bp_cache, [edge]; kwargs...) +end + function update(bp_cache::BeliefPropagationCache, edge::PartitionEdge; kwargs...) return update(bp_cache, [edge]; kwargs...) end @@ -205,7 +247,7 @@ end """ Do parallel updates between groups of edges of all message tensors Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. +mts relevant to that group. Out-of-place only for now. """ function update( bp_cache::BeliefPropagationCache, @@ -239,7 +281,7 @@ function update( end for i in 1:maxiter diff = compute_error ? Ref(0.0) : nothing - bp_cache = update(bp_cache, edges; (update_diff!)=diff, kwargs...) + bp_cache = update_V1(bp_cache, edges; (update_diff!)=diff, kwargs...) if compute_error && (diff.x / length(edges)) <= tol if verbose println("BP converged to desired precision after $i iterations.") @@ -253,8 +295,7 @@ end """ Update the tensornetwork inside the cache """ -function update_factors(bp_cache::BeliefPropagationCache, factors) - bp_cache = copy(bp_cache) +function update_factors!(bp_cache::BeliefPropagationCache, factors) tn = tensornetwork(bp_cache) for vertex in eachindex(factors) # TODO: Add a check that this preserves the graph structure. @@ -263,6 +304,16 @@ function update_factors(bp_cache::BeliefPropagationCache, factors) return bp_cache end +function update_factors(bp_cache::BeliefPropagationCache, factors) + bp_cache_updated = copy(bp_cache) + bp_cache_updated = update_factors!(bp_cache_updated, factors) + return bp_cache_updated +end + +function update_factor!(bp_cache, vertex, factor) + return update_factors!(bp_cache, Dictionary([vertex], [factor])) +end + function update_factor(bp_cache, vertex, factor) return update_factors(bp_cache, Dictionary([vertex], [factor])) end From bdf2b4dedda5add729703c402056e5a045e13966 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 11 Jul 2024 13:11:53 +0200 Subject: [PATCH 2/7] avoid use of get --- src/caches/beliefpropagationcache.jl | 106 ++++++++------------------- test/test_belief_propagation.jl | 2 - 2 files changed, 32 insertions(+), 76 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index f3b42b7e..54a392c3 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -1,6 +1,6 @@ using Graphs: IsDirected using SplitApplyCombine: group -using LinearAlgebra: diag +using LinearAlgebra: diag, dot using ITensors: dir using ITensorMPS: ITensorMPS using NamedGraphs.PartitionedGraphs: @@ -12,16 +12,21 @@ using NamedGraphs.PartitionedGraphs: partitionedges, unpartitioned_graph using SimpleTraits: SimpleTraits, Not, @traitfn +using NDTensors: NDTensors -default_message(inds_e) = ITensor[denseblocks(delta(i)) for i in inds_e] +default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] default_messages(ptn::PartitionedGraph) = Dictionary() -default_message_norm(m::ITensor) = norm(m) + function default_message_update(contract_list::Vector{ITensor}; kwargs...) sequence = optimal_contraction_sequence(contract_list) updated_messages = contract(contract_list; sequence, kwargs...) - updated_messages /= norm(updated_messages) + message_norm = norm(updated_messages) + if !iszero(message_norm) + updated_messages /= message_norm + end return ITensor[updated_messages] end + @traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing @traitfn function default_bp_maxiter(g::::IsDirected) return default_bp_maxiter(undirected_graph(underlying_graph(g))) @@ -30,17 +35,16 @@ default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertice function default_partitioned_vertices(f::AbstractFormNetwork) return group(v -> original_state_vertex(f, v), vertices(f)) end -default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) +default_cache_update_kwargs(cache) = (; maxiter=25, tol=1e-8) function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork) return (; partitioned_vertices=default_partitioned_vertices(ψ)) end -function message_diff( - message_a::Vector{ITensor}, message_b::Vector{ITensor}; message_norm=default_message_norm -) +#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages +function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) lhs, rhs = contract(message_a), contract(message_b) - norm_lhs, norm_rhs = message_norm(lhs), message_norm(rhs) - return 0.5 * norm((denseblocks(lhs) / norm_lhs) - (denseblocks(rhs) / norm_rhs)) + f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) + return 1 - f end struct BeliefPropagationCache{PTN,MTS,DM} @@ -96,13 +100,17 @@ for f in [ end end +NDTensors.scalartype(bp_cache) = scalartype(tensornetwork(bp_cache)) + function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) - return default_message(bp_cache)(linkinds(bp_cache, edge)) + return default_message(bp_cache)(scalartype(bp_cache), linkinds(bp_cache, edge)) end function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) mts = messages(bp_cache) - return get(mts, edge, default_message(bp_cache, edge)) + #return get(mts, edge, default_message(bp_cache, edge)) + haskey(mts, edge) && return mts[edge] + return default_message(bp_cache, edge) end function messages(bp_cache::BeliefPropagationCache, edges; kwargs...) return map(edge -> message(bp_cache, edge; kwargs...), edges) @@ -148,15 +156,16 @@ end function environment(bp_cache::BeliefPropagationCache, verts::Vector) partition_verts = partitionvertices(bp_cache, verts) messages = environment(bp_cache, partition_verts) - central_tensors = ITensor[ - tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts) - ] + central_tensors = factors(bp_cache, setdiff(vertices(bp_cache, partition_verts), verts)) return vcat(messages, central_tensors) end +function factors(bp_cache::BeliefPropagationCache, verts::Vector) + return ITensor[tensornetwork(bp_cache)[v] for v in verts] +end + function factor(bp_cache::BeliefPropagationCache, vertex::PartitionVertex) - ptn = partitioned_tensornetwork(bp_cache) - return collect(eachtensor(subgraph(ptn, vertex))) + return factors(bp_cache, vertices(bp_cache, vertex)) end """ @@ -178,30 +187,7 @@ end """ Do a sequential update of the message tensors on `edges` """ -function update!( - bp_cache::BeliefPropagationCache, - edges::Vector{<:PartitionEdge}; - (update_diff!)=nothing, - kwargs..., -) - prev_mts = copy(messages(bp_cache)) - mts = messages(bp_cache) - for e in edges - if !haskey(prev_mts, e) - set!(prev_mts, e, default_message(bp_cache, e)) - end - set!(mts, e, update_message(bp_cache, e; kwargs...)) - if !isnothing(update_diff!) - update_diff![] += message_diff(prev_mts[e], mts[e]) - end - end - return bp_cache -end - -""" -Do a sequential update of the message tensors on `edges` -""" -function update_V1( +function update( bp_cache::BeliefPropagationCache, edges::Vector{<:PartitionEdge}; (update_diff!)=nothing, @@ -210,7 +196,7 @@ function update_V1( bp_cache_updated = copy(bp_cache) mts = messages(bp_cache_updated) for e in edges - set!(mts, e, update_message(bp_cache, e; kwargs...)) + set!(mts, e, update_message(bp_cache_updated, e; kwargs...)) if !isnothing(update_diff!) update_diff![] += message_diff(message(bp_cache, e), mts[e]) end @@ -218,28 +204,9 @@ function update_V1( return bp_cache_updated end -""" -Out of place version -""" -function update_V2( - bp_cache::BeliefPropagationCache, - edges::Vector{<:PartitionEdge}; - (update_diff!)=nothing, - kwargs..., -) - bp_cache_updated = copy(bp_cache) - bp_cache_updated = update!(bp_cache_updated, edges; (update_diff!), kwargs...) - return bp_cache_updated -end - - """ Update the message tensor on a single edge """ -function update!(bp_cache::BeliefPropagationCache, edge::PartitionEdge; kwargs...) - return update!(bp_cache, [edge]; kwargs...) -end - function update(bp_cache::BeliefPropagationCache, edge::PartitionEdge; kwargs...) return update(bp_cache, [edge]; kwargs...) end @@ -247,7 +214,7 @@ end """ Do parallel updates between groups of edges of all message tensors Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. Out-of-place only for now. +mts relevant to that group. """ function update( bp_cache::BeliefPropagationCache, @@ -281,7 +248,7 @@ function update( end for i in 1:maxiter diff = compute_error ? Ref(0.0) : nothing - bp_cache = update_V1(bp_cache, edges; (update_diff!)=diff, kwargs...) + bp_cache = update(bp_cache, edges; (update_diff!)=diff, kwargs...) if compute_error && (diff.x / length(edges)) <= tol if verbose println("BP converged to desired precision after $i iterations.") @@ -295,7 +262,8 @@ end """ Update the tensornetwork inside the cache """ -function update_factors!(bp_cache::BeliefPropagationCache, factors) +function update_factors(bp_cache::BeliefPropagationCache, factors) + bp_cache = copy(bp_cache) tn = tensornetwork(bp_cache) for vertex in eachindex(factors) # TODO: Add a check that this preserves the graph structure. @@ -304,16 +272,6 @@ function update_factors!(bp_cache::BeliefPropagationCache, factors) return bp_cache end -function update_factors(bp_cache::BeliefPropagationCache, factors) - bp_cache_updated = copy(bp_cache) - bp_cache_updated = update_factors!(bp_cache_updated, factors) - return bp_cache_updated -end - -function update_factor!(bp_cache, vertex, factor) - return update_factors!(bp_cache, Dictionary([vertex], [factor])) -end - function update_factor(bp_cache, vertex, factor) return update_factors(bp_cache, Dictionary([vertex], [factor])) end diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index 66cf25d7..37222034 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -1,8 +1,6 @@ @eval module $(gensym()) using Compat: Compat using Graphs: vertices -# Trigger package extension. -using GraphsFlows: GraphsFlows using ITensorNetworks: ITensorNetworks, BeliefPropagationCache, From fc96369d7c73c65811b8d650176b416a310078b3 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 11 Jul 2024 13:23:22 +0200 Subject: [PATCH 3/7] Restore file --- test/test_treetensornetworks/test_solvers/test_dmrg.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index 3bc75d3d..76addf78 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -184,13 +184,11 @@ end else # when using no QNs, autofermion breaks # ToDo reference Issue in ITensors ITensors.disable_auto_fermion() end - - tooth_lengths = fill(2, 3) - c = named_comb_tree(tooth_lengths) s = siteinds("S=1/2", c; conserve_qns=use_qns) + os = ModelHamiltonians.heisenberg(c) + H = ttn(os, s) - e, psi = dmrg(H, psi; dmrg_kwargs) # make init_state d = Dict() @@ -299,4 +297,4 @@ end @test all(edge_data(linkdims(psi)) .<= maxdim) end -end +end \ No newline at end of file From 6d0bb52d6cd124624943a8bba3d6c14fe1e24b3f Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 11 Jul 2024 13:24:29 +0200 Subject: [PATCH 4/7] Restore file --- test/test_treetensornetworks/test_solvers/test_dmrg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index 76addf78..b352d43c 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -297,4 +297,4 @@ end @test all(edge_data(linkdims(psi)) .<= maxdim) end -end \ No newline at end of file +end From b5699e54cedbda3f4eed19ba645115ebf76d63c2 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 11 Jul 2024 13:25:38 +0200 Subject: [PATCH 5/7] Restore file --- src/caches/beliefpropagationcache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index be026744..33f2051d 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -14,7 +14,7 @@ using NamedGraphs.PartitionedGraphs: using SimpleTraits: SimpleTraits, Not, @traitfn using NDTensors: NDTensors -default_message(inds_e) = ITensor[denseblocks(delta(i)) for i in inds_e] +default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] default_messages(ptn::PartitionedGraph) = Dictionary() function default_message_update(contract_list::Vector{ITensor}; kwargs...) sequence = optimal_contraction_sequence(contract_list) From e1e13ba662472ac31970604ed6cb9595f5530d8f Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 11 Jul 2024 13:26:34 +0200 Subject: [PATCH 6/7] Restore file --- src/caches/beliefpropagationcache.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 33f2051d..a3094f12 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -25,7 +25,6 @@ function default_message_update(contract_list::Vector{ITensor}; kwargs...) end return ITensor[updated_messages] end - @traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing @traitfn function default_bp_maxiter(g::::IsDirected) return default_bp_maxiter(undirected_graph(underlying_graph(g))) From 22515eaa71660c1a1c01eef7243f26514c1913e5 Mon Sep 17 00:00:00 2001 From: Joseph Tindall <51231103+JoeyT1994@users.noreply.github.com> Date: Sat, 27 Jul 2024 19:51:08 -0400 Subject: [PATCH 7/7] Use `get` again but with first argument --- src/caches/beliefpropagationcache.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index a3094f12..d80b2644 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -106,8 +106,7 @@ end function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) mts = messages(bp_cache) - haskey(mts, edge) && return mts[edge] - return default_message(bp_cache, edge) + return get(() -> default_message(bp_cache, edge), mts, edge) end function messages(bp_cache::BeliefPropagationCache, edges; kwargs...) return map(edge -> message(bp_cache, edge; kwargs...), edges)