Skip to content

Commit

Permalink
revise the ExactTreeWidth solver with TreeWidthSolver 0.2.0 (#46)
Browse files Browse the repository at this point in the history
* revise the ExactTreeWidth solver with TreeWidthSolver 0.2.0

* fix contruction with scalars

* add tests for ExactTreewidth solver with scalars

* fixed test
  • Loading branch information
ArrogantGao authored Aug 7, 2024
1 parent 6f0de48 commit 36e0250
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
LuxorGraphPlot = "0.5.1"
TreeWidthSolver = "0.1.0"
TreeWidthSolver = "0.2"
julia = "1.9"

[extras]
Expand Down
27 changes: 15 additions & 12 deletions src/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ ab, ab -> a
"""
function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT}
indicies = collect(keys(incidence_list.e2v))
tensors = collect(keys(incidence_list.v2e))
weights = [log2_edge_sizes[e] for e in indicies]
line_graph = il2lg(incidence_list, indicies)

scalars = [i for i in tensors if isempty(incidence_list.v2e[i])]
contraction_trees = Vector{Union{ContractionTree, VT}}()

# avoid the case that the line graph is not connected
Expand All @@ -71,8 +73,9 @@ function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_
contraction_tree = eo2ct(elimination_order, incidence_list, log2_edge_sizes, α, temperature, nrepeat)
push!(contraction_trees, contraction_tree)
end

return reduce((x,y) -> ContractionTree(x, y), contraction_trees)

# add the scalars back to the contraction tree
return reduce((x,y) -> ContractionTree(x, y), contraction_trees scalars)
end

# transform incidence list to line graph
Expand Down Expand Up @@ -104,16 +107,16 @@ function eo2ct(elimination_order::EliminationOrder, incidence_list::IncidenceLis
flag = contraction_tree_nodes[1]

while !isempty(eo)
e = pop!(eo)
if haskey(incidence_list.e2v, e)
vs = incidence_list.e2v[e]
if length(vs) >= 2
sub_list = IncidenceList(Dict([v => incidence_list.v2e[v] for v in vs]); openedges=incidence_list.openedges)
sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; nrepeat=nrepeat, α=α, temperature=temperature)
vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs)
contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes)
flag = vi
end
eliminated_vertices = pop!(eo) # e is a vector of vertices, which are eliminated at the same time
vs = unique!(vcat([incidence_list.e2v[ei] for ei in eliminated_vertices if haskey(incidence_list.e2v, ei)]...)) # the tensors to be contracted, since they are connected to the eliminated vertices
if length(vs) >= 2
sub_list_indices = unique!(vcat([incidence_list.v2e[v] for v in vs]...)) # the vertices connected to the tensors to be contracted
sub_list_open_indices = setdiff(sub_list_indices, eliminated_vertices) # the vertices connected to the tensors to be contracted but not eliminated
sub_list = IncidenceList(Dict([v => incidence_list.v2e[v] for v in vs]); openedges=sub_list_open_indices) # the subgraph of the contracted tensors
sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; nrepeat=nrepeat, α=α, temperature=temperature) # optmize the subgraph with greedy method
vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs) # insert the contracted tensors back to the total graph
contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes)
flag = vi
end
end

Expand Down
6 changes: 6 additions & 0 deletions test/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,10 @@ using Test, Random
cc = contraction_complexity(optcode, size_dict)
@test cc.sc == 7
@test decorate(eincode)(tensors...) decorate(optcode)(tensors...)

eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e'], ['e'], ['f'], Char[]], ['a', 'f'])
tensors = tensors [fill(2.0,())]
optcode = optimize_exact_treewidth(optimizer, eincode, size_dict)
cc = contraction_complexity(optcode, size_dict)
@test decorate(eincode)(tensors...) decorate(optcode)(tensors...)
end

0 comments on commit 36e0250

Please sign in to comment.