diff --git a/Project.toml b/Project.toml index 613aa19..17c5af9 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ KaHyPar = "0.3" StatsBase = "0.34" Suppressor = "0.2" LuxorGraphPlot = "0.5.1" -TreeWidthSolver = "0.1.0" +TreeWidthSolver = "0.2" julia = "1.9" [extras] diff --git a/src/treewidth.jl b/src/treewidth.jl index 74cfb7c..48659e0 100644 --- a/src/treewidth.jl +++ b/src/treewidth.jl @@ -55,9 +55,11 @@ ab, ab -> a """ function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT} indicies = collect(keys(incidence_list.e2v)) + tensors = collect(keys(incidence_list.v2e)) weights = [log2_edge_sizes[e] for e in indicies] line_graph = il2lg(incidence_list, indicies) + scalars = [i for i in tensors if isempty(incidence_list.v2e[i])] contraction_trees = Vector{Union{ContractionTree, VT}}() # avoid the case that the line graph is not connected @@ -71,8 +73,9 @@ function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_ contraction_tree = eo2ct(elimination_order, incidence_list, log2_edge_sizes, α, temperature, nrepeat) push!(contraction_trees, contraction_tree) end - - return reduce((x,y) -> ContractionTree(x, y), contraction_trees) + + # add the scalars back to the contraction tree + return reduce((x,y) -> ContractionTree(x, y), contraction_trees ∪ scalars) end # transform incidence list to line graph @@ -104,16 +107,16 @@ function eo2ct(elimination_order::EliminationOrder, incidence_list::IncidenceLis flag = contraction_tree_nodes[1] while !isempty(eo) - e = pop!(eo) - if haskey(incidence_list.e2v, e) - vs = incidence_list.e2v[e] - if length(vs) >= 2 - sub_list = IncidenceList(Dict([v => incidence_list.v2e[v] for v in vs]); openedges=incidence_list.openedges) - sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; nrepeat=nrepeat, α=α, temperature=temperature) - vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs) - contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes) - flag = vi - end + eliminated_vertices = pop!(eo) # e is a vector of vertices, which are eliminated at the same time + vs = unique!(vcat([incidence_list.e2v[ei] for ei in eliminated_vertices if haskey(incidence_list.e2v, ei)]...)) # the tensors to be contracted, since they are connected to the eliminated vertices + if length(vs) >= 2 + sub_list_indices = unique!(vcat([incidence_list.v2e[v] for v in vs]...)) # the vertices connected to the tensors to be contracted + sub_list_open_indices = setdiff(sub_list_indices, eliminated_vertices) # the vertices connected to the tensors to be contracted but not eliminated + sub_list = IncidenceList(Dict([v => incidence_list.v2e[v] for v in vs]); openedges=sub_list_open_indices) # the subgraph of the contracted tensors + sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; nrepeat=nrepeat, α=α, temperature=temperature) # optmize the subgraph with greedy method + vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs) # insert the contracted tensors back to the total graph + contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes) + flag = vi end end diff --git a/test/treewidth.jl b/test/treewidth.jl index 901e941..79b724e 100644 --- a/test/treewidth.jl +++ b/test/treewidth.jl @@ -34,4 +34,10 @@ using Test, Random cc = contraction_complexity(optcode, size_dict) @test cc.sc == 7 @test decorate(eincode)(tensors...) ≈ decorate(optcode)(tensors...) + + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e'], ['e'], ['f'], Char[]], ['a', 'f']) + tensors = tensors ∪ [fill(2.0,())] + optcode = optimize_exact_treewidth(optimizer, eincode, size_dict) + cc = contraction_complexity(optcode, size_dict) + @test decorate(eincode)(tensors...) ≈ decorate(optcode)(tensors...) end \ No newline at end of file