Skip to content

Commit

Permalink
Added a solver based on exact tree width solver (#43)
Browse files Browse the repository at this point in the history
* add exact tree width solver

* using reduce for vector of ContractionTree

* add tree_reformulate

* add tree reformulation for tree width

* add tree reformulator

* change tree_reformulate as pivot_tree

* add compat for TreeWidthSolver

* add test of ExactTreeWidth interface

* remove Graphs from deps by using TreeWidthSolver.Graphs

* update treewidth pivot

* fix a few bugs

* fix function name

---------

Co-authored-by: GiggleLiu <[email protected]>
  • Loading branch information
ArrogantGao and GiggleLiu authored Aug 3, 2024
1 parent 3031953 commit 2be8ee9
Show file tree
Hide file tree
Showing 13 changed files with 379 additions and 13 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TreeWidthSolver = "7d267fc5-9ace-409f-a54c-cd2374872a55"

[weakdeps]
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Expand All @@ -22,6 +23,7 @@ JSON = "0.21"
KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
TreeWidthSolver = "0.1.0"
julia = "1.9"

[extras]
Expand Down
77 changes: 77 additions & 0 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,83 @@ connector(::Type{Char}) = ""
connector(::Type{Int}) = ""
connector(::Type) = "-"

function is_unary_or_binary(code::NestedEinsum)
if isleaf(code) return true end
if length(code.args) > 2 return false end
return all(is_unary_or_binary, code.args)
end


# reformulate the nested einsum, removing a given tensor without change the space complexity
# consider only binary contraction tree with no openedges
function pivot_tree(code::NestedEinsum{LT}, removed_tensor_id::Int) where LT
@assert is_unary_or_binary(code) "The contraction tree is not binary"
@assert isempty(getiyv(code)) "The contraction tree has open edges"

path = path_to_tensor(code, removed_tensor_id)
isempty(path) && return code # the tensor is at the root?

right = popfirst!(path)
left = right == 1 ? 2 : 1

if isleaf(code.args[left]) && isleaf(code.args[right])
ixsv = getixsv(code.eins)
return NestedEinsum([code.args[left]], EinCode([ixsv[left]], ixsv[right]))
elseif isleaf(code.args[right])
return NestedEinsum([code.args[left].args...], EinCode(getixsv(code.args[left].eins), getixsv(code.eins)[right]))
else
# update the ein code to make sure the root of the left part and the right part are the same
left_code = code.args[left]
right_code = NestedEinsum([code.args[right].args...], EinCode(getixsv(code.args[right].eins), getixsv(code.eins)[left]))
end
tree = _pivot_tree!(left_code, right_code, path)

return tree
end


function _pivot_tree!(left_code::NestedEinsum{LT}, right_code::NestedEinsum{LT}, path::Vector{Int}) where{LT}
if !isleaf(right_code)
right = popfirst!(path)
left = right == 1 ? 2 : 1
if length(right_code.args) == 1
# orign: left: a, right: b -> a
# reformulated: left: a -> b, right: b
new_eins = EinCode([getiyv(right_code.eins)], getixsv(right_code.eins)[1])
left_code = NestedEinsum([left_code], new_eins)
left_code = _pivot_tree!(left_code, right_code.args[1], path)
elseif length(right_code.args) == 2
# origin: left: a, right: b, c -> a
# reformulated: left: a, b -> c, right: c
new_eins = EinCode([getiyv(right_code.eins), getixsv(right_code.eins)[left]], getixsv(right_code.eins)[right])
left_code = NestedEinsum([left_code, right_code.args[left]], new_eins)
left_code = _pivot_tree!(left_code, right_code.args[right], path)
else
error("The contraction tree is not binary")
end
end
return left_code
end

# find the path to a given tensor in a nested einsum
function path_to_tensor(code::NestedEinsum, index::Int)
path = Vector{Int}()
_find_root!(code, index, path)
return path
end

function _find_root!(code::NestedEinsum, index::Int, path::Vector{Int})
if isleaf(code) return code.tensorindex == index end

for (i, arg) in enumerate(code.args)
if _find_root!(arg, index, path)
pushfirst!(path, i)
return true
end
end
return false
end

############### Simplifier and optimizer types #################
abstract type CodeSimplifier end

Expand Down
7 changes: 6 additions & 1 deletion src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ using StatsBase
using Base: RefValue
using Base.Threads
using AbstractTrees
using TreeWidthSolver
using TreeWidthSolver.Graphs

export CodeOptimizer, CodeSimplifier,
KaHyParBipartite, GreedyMethod, TreeSA, SABipartite,
KaHyParBipartite, GreedyMethod, TreeSA, SABipartite, ExactTreewidth,
MergeGreedy, MergeVectors,
uniformsize,
simplify_code, optimize_code, optimize_permute,
Expand All @@ -31,6 +33,9 @@ include("kahypar.jl")
# local search method
include("treesa.jl")

# tree width method
include("treewidth.jl")

# simplification passes
include("simplify.jl")

Expand Down
5 changes: 4 additions & 1 deletion src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Returns a `NestedEinsum` instance. Input arguments are
* `eincode` is an einsum contraction code instance, one of `DynamicEinCode`, `StaticEinCode` or `NestedEinsum`.
* `size` is a dictionary of "edge label=>edge size" that contains the size information, one can use `uniformsize(eincode, 2)` to create a uniform size.
* `optimizer` is a `CodeOptimizer` instance, should be one of `GreedyMethod`, `KaHyParBipartite`, `SABipartite` or `TreeSA`. Check their docstrings for details.
* `optimizer` is a `CodeOptimizer` instance, should be one of `GreedyMethod`, `ExactTreewidth`, `KaHyParBipartite`, `SABipartite` or `TreeSA`. Check their docstrings for details.
* `simplifier` is one of `MergeVectors` or `MergeGreedy`.
* optimize the permutation if `permute` is true.
Expand Down Expand Up @@ -53,6 +53,9 @@ end
function _optimize_code(code, size_dict, optimizer::GreedyMethod)
optimize_greedy(code, size_dict; α = optimizer.α, temperature = optimizer.temperature, nrepeat=optimizer.nrepeat)
end
function _optimize_code(code, size_dict, optimizer::ExactTreewidth)
optimize_exact_treewidth(optimizer, code, size_dict)
end
function _optimize_code(code, size_dict, optimizer::SABipartite)
recursive_bipartite_optimize(optimizer, code, size_dict)
end
Expand Down
2 changes: 1 addition & 1 deletion src/json.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function _todict(ne::SlicedEinsum)
end
function _todict(ne::NestedEinsum)
LT = labeltype(ne)
dict = Dict{String,Any}("label-type"=>LT, "inputs"=>getixsv(ne), "output"=>getiyv(ne))
dict = Dict{String,Any}("label-type"=>string(LT), "inputs"=>getixsv(ne), "output"=>getiyv(ne))
dict["tree"] = todict(ne)
return dict
end
Expand Down
8 changes: 5 additions & 3 deletions src/kahypar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function adjacency_matrix(ixs::AbstractVector)
push!(rows, map(x->i, ix)...)
push!(cols, map(x->findfirst(==(x), edges), ix)...)
end
return sparse(rows, cols, ones(Int, length(rows))), edges
return sparse(rows, cols, ones(Int, length(rows)), length(ixs), length(edges)), edges
end

# legacy interface
Expand All @@ -149,10 +149,12 @@ end
function recursive_bipartite_optimize(bipartiter, code::EinCode, size_dict)
ixs, iy = getixsv(code), getiyv(code)
ixv = [ixs..., iy]

adj, edges = adjacency_matrix(ixv)
vertices=collect(1:length(ixs))
vertices=collect(1:length(ixv))
parts = bipartition_recursive(bipartiter, adj, vertices, [log2(size_dict[e]) for e in edges])
recursive_construct_nestedeinsum(ixv, iy, parts, size_dict, 0, bipartiter.sub_optimizer)
optcode = recursive_construct_nestedeinsum(ixv, empty(iy), parts, size_dict, 0, bipartiter.sub_optimizer)
return pivot_tree(optcode, length(ixs) + 1)
end

maplocs(ne::NestedEinsum{ET}, parts) where ET = isleaf(ne) ? NestedEinsum{ET}(parts[ne.tensorindex]) : NestedEinsum(maplocs.(ne.args, Ref(parts)), ne.eins)
Expand Down
12 changes: 7 additions & 5 deletions src/sa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,18 @@ function partition_state(adj, group, config, log2_sizes)
end

function bipartite_sc(bipartiter::SABipartite, adj::SparseMatrixCSC, vertices, log2_sizes)
@assert length(vertices) >= 2
degrees_all = sum(adj, dims=1)
adjt = SparseMatrixCSC(adj')
config = _initialize(bipartiter.initializer,adj, vertices, log2_sizes)
config = _initialize(bipartiter.initializer, adj, vertices, log2_sizes)
if all(config .== 1) || all(config .== 2)
config[1] = 3 - config[1] # flip the first group to avoid empty group
end
best = partition_state(adj, vertices, config, log2_sizes) # this is the `state` of current partition.

for _ = 1:bipartiter.ntrials
config = _initialize(bipartiter.initializer,adj, vertices, log2_sizes)
state = partition_state(adj, vertices, config, log2_sizes) # this is the `state` of current partition.
if state.group_sizes[1]==0 || state.group_sizes[2] == 0
continue
end

@inbounds for β in bipartiter.βs, iter = 1:bipartiter.niters
idxi = rand(1:length(vertices))
Expand All @@ -78,6 +79,7 @@ function bipartite_sc(bipartiter::SABipartite, adj::SparseMatrixCSC, vertices, l
end
accept && update_state!(state, adjt, vertices, idxi, sc_ti, sc_tinew, newloss)
end
(state.group_sizes[1]==0 || state.group_sizes[2] == 0) && continue
tc, sc1, sc2 = timespace_complexity_singlestep(state.config, adj, vertices, log2_sizes)
@assert state.group_scs [sc1, sc2] # sanity check
if maximum(state.group_scs) <= max(bipartiter.sc_target, maximum(best.group_scs)) && (maximum(best.group_scs) >= bipartiter.sc_target || state.loss[] < best.loss[])
Expand All @@ -87,7 +89,7 @@ function bipartite_sc(bipartiter::SABipartite, adj::SparseMatrixCSC, vertices, l
best_tc, = timespace_complexity_singlestep(best.config, adj, vertices, log2_sizes)
@debug "best loss = $(round(best.loss[]; digits=3)) space complexities = $(best.group_scs) time complexity = $(best_tc) groups_sizes = $(best.group_sizes)"
if maximum(best.group_scs) > bipartiter.sc_target
@warn "target space complexity not found, got: $(maximum(best.group_scs)), with time complexity $best_tc."
@warn "target space complexity $(bipartiter.sc_target) not found, got: $(maximum(best.group_scs)), with time complexity $best_tc."
end
return vertices[findall(==(1), best.config)], vertices[findall(==(2), best.config)]
end
Expand Down
160 changes: 160 additions & 0 deletions src/treewidth.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""
struct ExactTreewidth{GM} <: CodeOptimizer
ExactTreewidth(greedy_config::GM = GreedyMethod(nrepeat=1))
A optimizer using the exact tree width solver proved in TreeWidthSolver.jl, the greedy_config is the configuration for the greedy method, which is used to solve the subproblems in the tree decomposition.
# Fields
- `greedy_config::GM`: The configuration for the greedy method.
"""
Base.@kwdef struct ExactTreewidth{GM} <: CodeOptimizer
greedy_config::GM = GreedyMethod(nrepeat=1)
end

"""
exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT}
This function calculates the exact treewidth of a graph using TreeWidthSolver.jl. It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input. The function also accepts optional parameters `α`, `temperature`, and `nrepeat` with default values of 0.0, 0.0, and 1 respectively, which are parameter of the GreedyMethod used in the contraction process.
## Arguments
- `incidence_list`: An incidence list representation of the graph.
- `log2_edge_sizes`: A dictionary of logarithm base 2 edge sizes.
## Returns
- The function returns a `ContractionTree` representing the contraction process.
```
julia> optimizer = OMEinsumContractionOrders.ExactTreewidth()
OMEinsumContractionOrders.ExactTreewidth{GreedyMethod{Float64, Float64}}(GreedyMethod{Float64, Float64}(0.0, 0.0, 1))
julia> eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
ab, acd, bcef, e, df -> a
julia> size_dict = Dict([c=>(1<<i) for (i,c) in enumerate(['a', 'b', 'c', 'd', 'e', 'f'])]...)
Dict{Char, Int64} with 6 entries:
'f' => 64
'a' => 2
'c' => 8
'd' => 16
'e' => 32
'b' => 4
julia> optcode = optimize_code(eincode, size_dict, optimizer)
ab, ab -> a
├─ ab
└─ fac, bcf -> ab
├─ df, acd -> fac
│ ├─ df
│ └─ acd
└─ e, bcef -> bcf
├─ e
└─ bcef
```
"""
function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT}
indicies = collect(keys(incidence_list.e2v))
weights = [log2_edge_sizes[e] for e in indicies]
line_graph = il2lg(incidence_list, indicies)

contraction_trees = Vector{Union{ContractionTree, VT}}()

# avoid the case that the line graph is not connected
for vertice_ids in connected_components(line_graph)
lg = induced_subgraph(line_graph, vertice_ids)[1]
lg_indicies = indicies[vertice_ids]
lg_weights = weights[vertice_ids]
labeled_graph = LabeledSimpleGraph(lg, lg_indicies, lg_weights)
tree_decomposition = exact_treewidth(labeled_graph)
elimination_order = EliminationOrder(tree_decomposition.tree)
contraction_tree = eo2ct(elimination_order, incidence_list, log2_edge_sizes, α, temperature, nrepeat)
push!(contraction_trees, contraction_tree)
end

return reduce((x,y) -> ContractionTree(x, y), contraction_trees)
end

# transform incidence list to line graph
function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) where {VT, ET}

line_graph = SimpleGraph(length(indicies))

for (i, e) in enumerate(indicies)
for v in incidence_list.e2v[e]
for ej in incidence_list.v2e[v]
if e != ej add_edge!(line_graph, i, findfirst(==(ej), indicies)) end
end
end
end

return line_graph
end

# transform elimination order to contraction tree
function eo2ct(elimination_order::EliminationOrder, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes, α::TA, temperature::TT, nrepeat) where {VT, ET, TA, TT}
eo = copy(elimination_order.order)
incidence_list = copy(incidence_list)
contraction_tree_nodes = Vector{Union{VT, ContractionTree}}(collect(keys(incidence_list.v2e)))
tensors_list = Dict{VT, Int}()
for (i, v) in enumerate(contraction_tree_nodes)
tensors_list[v] = i
end

flag = contraction_tree_nodes[1]

while !isempty(eo)
e = pop!(eo)
if haskey(incidence_list.e2v, e)
vs = incidence_list.e2v[e]
if length(vs) >= 2
sub_list = IncidenceList(Dict([v => incidence_list.v2e[v] for v in vs]); openedges=incidence_list.openedges)
sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; nrepeat=nrepeat, α=α, temperature=temperature)
vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs)
contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes)
flag = vi
end
end
end

return contraction_tree_nodes[tensors_list[flag]]
end

function st2ct(sub_tree::Union{ContractionTree, VT}, tensors_list::Dict{VT, Int}, contraction_tree_nodes::Vector{Union{ContractionTree, VT}}) where{VT}
if sub_tree isa ContractionTree
return ContractionTree(st2ct(sub_tree.left, tensors_list, contraction_tree_nodes), st2ct(sub_tree.right, tensors_list, contraction_tree_nodes))
else
return contraction_tree_nodes[tensors_list[sub_tree]]
end
end

"""
optimize_exact_treewidth(optimizer, eincode, size_dict)
Optimizing the contraction order via solve the exact tree width of the line graph corresponding to the eincode and return a `NestedEinsum` object.
Check the docstring of `exact_treewidth_method` for detailed explaination of other input arguments.
"""
function optimize_exact_treewidth(optimizer::ExactTreewidth{GM}, code::EinCode{L}, size_dict::Dict) where {L,GM}
optimize_exact_treewidth(optimizer, getixsv(code), getiyv(code), size_dict)
end
function optimize_exact_treewidth(optimizer::ExactTreewidth{GM}, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L,TI}) where {L, TI, GM}
if length(ixs) <= 2
return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy))
end
log2_edge_sizes = Dict{L,Float64}()
for (k, v) in size_dict
log2_edge_sizes[k] = log2(v)
end
# complete all open edges as a clique, connected with a dummy tensor
incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)] [(length(ixs) + 1 => iy)]))

α = optimizer.greedy_config.α
temperature = optimizer.greedy_config.temperature
nrepeat = optimizer.greedy_config.nrepeat
tree = exact_treewidth_method(incidence_list, log2_edge_sizes; α = α, temperature = temperature, nrepeat=nrepeat)

# remove the dummy tensor added for open edges
optcode = parse_eincode!(incidence_list, tree, 1:length(ixs) + 1)[2]

return pivot_tree(optcode, length(ixs) + 1)
end
Loading

0 comments on commit 2be8ee9

Please sign in to comment.