Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a solver based on exact tree width solver #43

Merged
merged 12 commits into from
Aug 3, 2024
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ version = "0.8.3"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TreeWidthSolver = "7d267fc5-9ace-409f-a54c-cd2374872a55"

[weakdeps]
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Expand All @@ -22,6 +24,7 @@ JSON = "0.21"
KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
TreeWidthSolver = "0.1.0"
julia = "1.9"

[extras]
Expand Down
79 changes: 79 additions & 0 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,85 @@ connector(::Type{Char}) = ""
connector(::Type{Int}) = "∘"
connector(::Type) = "-"

function is_binary_tree(code::NestedEinsum)
if isleaf(code) return true end
if length(code.args) > 2 return false end
return all(is_binary_tree, code.args)
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
end


# reformulate the nested einsum, removing a given tensor without change the space complexity
# consider only binary contraction tree with no openedges
function pivot_tree(code::NestedEinsum, removed_tensor_id::Int)

try @assert is_binary_tree(code) catch
error("The contraction tree is not binary")
end

try @assert isempty(getiyv(code)) catch
error("The contraction tree has open edges")
end
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved

path = path_to_tensor(code, removed_tensor_id)

right = popfirst!(path)
left = right == 1 ? 2 : 1

if isleaf(code.args[right])
return NestedEinsum([code.args[left].args...], EinCode(getixsv(code.args[left].eins), getixsv(code.eins)[right]))
else
# update the ein code to make sure the root of the left part and the right part are the same
left_code = code.args[left]
right_code = NestedEinsum([code.args[right].args...], EinCode(getixsv(code.args[right].eins), getixsv(code.eins)[left]))
end
tree = _pivot_tree!(left_code, right_code, path)

return tree
end


function _pivot_tree!(left_code::NestedEinsum{LT}, right_code::NestedEinsum{LT}, path::Vector{Int}) where{LT}
if !isleaf(right_code)
right = popfirst!(path)
left = right == 1 ? 2 : 1
if length(right_code.args) == 1
# orign: left: a, right: b -> a
# reformulated: left: a -> b, right: b
new_eins = EinCode([getiyv(right_code.eins)], getixsv(right_code.eins)[1])
left_code = NestedEinsum([left_code], new_eins)
left_code = _pivot_tree!(left_code, right_code.args[1], path)
elseif length(right_code.args) == 2
# origin: left: a, right: b, c -> a
# reformulated: left: a, b -> c, right: c
new_eins = EinCode([getiyv(right_code.eins), getixsv(right_code.eins)[left]], getixsv(right_code.eins)[right])
left_code = NestedEinsum([left_code, right_code.args[left]], new_eins)
left_code = _pivot_tree!(left_code, right_code.args[right], path)
else
error("The contraction tree is not binary")
end
end
return left_code
end

# find the path to a given tensor in a nested einsum
function path_to_tensor(code::NestedEinsum, index::Int)
path = Vector{Int}()
_find_root!(code, index, path)
return path
end

function _find_root!(code::NestedEinsum, index::Int, path::Vector{Int})
if isleaf(code) return code.tensorindex == index end

for (i, arg) in enumerate(code.args)
if _find_root!(arg, index, path)
pushfirst!(path, i)
return true
end
end
return false
end

############### Simplifier and optimizer types #################
abstract type CodeSimplifier end

Expand Down
5 changes: 4 additions & 1 deletion src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Base.Threads
using AbstractTrees

export CodeOptimizer, CodeSimplifier,
KaHyParBipartite, GreedyMethod, TreeSA, SABipartite,
KaHyParBipartite, GreedyMethod, TreeSA, SABipartite, ExactTreewidth,
MergeGreedy, MergeVectors,
uniformsize,
simplify_code, optimize_code, optimize_permute,
Expand All @@ -31,6 +31,9 @@ include("kahypar.jl")
# local search method
include("treesa.jl")

# tree width method
include("treewidth.jl")

# simplification passes
include("simplify.jl")

Expand Down
5 changes: 4 additions & 1 deletion src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Returns a `NestedEinsum` instance. Input arguments are

* `eincode` is an einsum contraction code instance, one of `DynamicEinCode`, `StaticEinCode` or `NestedEinsum`.
* `size` is a dictionary of "edge label=>edge size" that contains the size information, one can use `uniformsize(eincode, 2)` to create a uniform size.
* `optimizer` is a `CodeOptimizer` instance, should be one of `GreedyMethod`, `KaHyParBipartite`, `SABipartite` or `TreeSA`. Check their docstrings for details.
* `optimizer` is a `CodeOptimizer` instance, should be one of `GreedyMethod`, `ExactTreewidth`, `KaHyParBipartite`, `SABipartite` or `TreeSA`. Check their docstrings for details.
* `simplifier` is one of `MergeVectors` or `MergeGreedy`.
* optimize the permutation if `permute` is true.

Expand Down Expand Up @@ -53,6 +53,9 @@ end
function _optimize_code(code, size_dict, optimizer::GreedyMethod)
optimize_greedy(code, size_dict; α = optimizer.α, temperature = optimizer.temperature, nrepeat=optimizer.nrepeat)
end
function _optimize_code(code, size_dict, optimizer::ExactTreewidth)
optimize_exact_treewidth(optimizer, code, size_dict)
end
function _optimize_code(code, size_dict, optimizer::SABipartite)
recursive_bipartite_optimize(optimizer, code, size_dict)
end
Expand Down
16 changes: 13 additions & 3 deletions src/kahypar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,21 @@ end

function recursive_bipartite_optimize(bipartiter, code::EinCode, size_dict)
ixs, iy = getixsv(code), getiyv(code)
ixv = [ixs..., iy]
if isempty(iy)
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
ixv = [ixs..., iy]
else
ixv = [ixs..., iy, empty(iy)]
end

adj, edges = adjacency_matrix(ixv)
vertices=collect(1:length(ixs))
vertices=collect(1:length(ixv) - 1)
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
parts = bipartition_recursive(bipartiter, adj, vertices, [log2(size_dict[e]) for e in edges])
recursive_construct_nestedeinsum(ixv, iy, parts, size_dict, 0, bipartiter.sub_optimizer)
optcode = recursive_construct_nestedeinsum(ixv, empty(iy), parts, size_dict, 0, bipartiter.sub_optimizer)
if isempty(iy)
return optcode
else
return pivot_tree(optcode, length(ixs) + 1)
end
end

maplocs(ne::NestedEinsum{ET}, parts) where ET = isleaf(ne) ? NestedEinsum{ET}(parts[ne.tensorindex]) : NestedEinsum(maplocs.(ne.args, Ref(parts)), ne.eins)
Expand Down
166 changes: 166 additions & 0 deletions src/treewidth.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
using TreeWidthSolver
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
using Graphs: connected_components, induced_subgraph, SimpleGraph, add_edge!

"""
struct ExactTreewidth{GM} <: CodeOptimizer
ExactTreewidth(greedy_config::GM = GreedyMethod(nrepeat=1))

A optimizer using the exact tree width solver proved in TreeWidthSolver.jl, the greedy_config is the configuration for the greedy method, which is used to solve the subproblems in the tree decomposition.

# Fields
- `greedy_config::GM`: The configuration for the greedy method.

"""
Base.@kwdef struct ExactTreewidth{GM} <: CodeOptimizer
greedy_config::GM = GreedyMethod(nrepeat=1)
end

"""
exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT}

This function calculates the exact treewidth of a graph using TreeWidthSolver.jl. It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input. The function also accepts optional parameters `α`, `temperature`, and `nrepeat` with default values of 0.0, 0.0, and 1 respectively, which are parameter of the GreedyMethod used in the contraction process.

## Arguments
- `incidence_list`: An incidence list representation of the graph.
- `log2_edge_sizes`: A dictionary of logarithm base 2 edge sizes.

## Returns
- The function returns a `ContractionTree` representing the contraction process.

```
julia> optimizer = OMEinsumContractionOrders.ExactTreewidth()
OMEinsumContractionOrders.ExactTreewidth{GreedyMethod{Float64, Float64}}(GreedyMethod{Float64, Float64}(0.0, 0.0, 1))

julia> eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
ab, acd, bcef, e, df -> a

julia> size_dict = Dict([c=>(1<<i) for (i,c) in enumerate(['a', 'b', 'c', 'd', 'e', 'f'])]...)
Dict{Char, Int64} with 6 entries:
'f' => 64
'a' => 2
'c' => 8
'd' => 16
'e' => 32
'b' => 4

julia> optcode = optimize_code(eincode, size_dict, optimizer)
ab, ab -> a
├─ ab
└─ fac, bcf -> ab
├─ df, acd -> fac
│ ├─ df
│ └─ acd
└─ e, bcef -> bcf
├─ e
└─ bcef
```

"""
function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT}
indicies = collect(keys(incidence_list.e2v))
weights = [log2_edge_sizes[e] for e in indicies]
line_graph = il2lg(incidence_list, indicies)

contraction_trees = Vector{Union{ContractionTree, VT}}()

# avoid the case that the line graph is not connected
for vertice_ids in connected_components(line_graph)
lg = induced_subgraph(line_graph, vertice_ids)[1]
lg_indicies = indicies[vertice_ids]
lg_weights = weights[vertice_ids]
labeled_graph = LabeledSimpleGraph(lg, lg_indicies, lg_weights)
tree_decomposition = exact_treewidth(labeled_graph)
elimination_order = EliminationOrder(tree_decomposition.tree)
contraction_tree = eo2ct(elimination_order, incidence_list, log2_edge_sizes, α, temperature, nrepeat)
push!(contraction_trees, contraction_tree)
end

return reduce((x,y) -> ContractionTree(x, y), contraction_trees)
end

# transform incidence list to line graph
function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) where {VT, ET}

line_graph = SimpleGraph(length(indicies))

for (i, e) in enumerate(indicies)
for v in incidence_list.e2v[e]
for ej in incidence_list.v2e[v]
if e != ej add_edge!(line_graph, i, findfirst(==(ej), indicies)) end
end
end
end

return line_graph
end

# transform elimination order to contraction tree
function eo2ct(elimination_order::EliminationOrder, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes, α::TA, temperature::TT, nrepeat) where {VT, ET, TA, TT}
eo = copy(elimination_order.order)
incidence_list = copy(incidence_list)
contraction_tree_nodes = Vector{Union{VT, ContractionTree}}(collect(keys(incidence_list.v2e)))
tensors_list = Dict{VT, Int}()
for (i, v) in enumerate(contraction_tree_nodes)
tensors_list[v] = i
end

flag = contraction_tree_nodes[1]

while !isempty(eo)
e = pop!(eo)
if haskey(incidence_list.e2v, e)
vs = incidence_list.e2v[e]
if length(vs) >= 2
sub_list = IncidenceList(Dict([v => incidence_list.v2e[v] for v in vs]); openedges=incidence_list.openedges)
sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; nrepeat=nrepeat, α=α, temperature=temperature)
vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs)
contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes)
flag = vi
end
end
end

return contraction_tree_nodes[tensors_list[flag]]
end

function st2ct(sub_tree::Union{ContractionTree, VT}, tensors_list::Dict{VT, Int}, contraction_tree_nodes::Vector{Union{ContractionTree, VT}}) where{VT}
if sub_tree isa ContractionTree
return ContractionTree(st2ct(sub_tree.left, tensors_list, contraction_tree_nodes), st2ct(sub_tree.right, tensors_list, contraction_tree_nodes))
else
return contraction_tree_nodes[tensors_list[sub_tree]]
end
end

"""
optimize_exact_treewidth(optimizer, eincode, size_dict)

Optimizing the contraction order via solve the exact tree width of the line graph corresponding to the eincode and return a `NestedEinsum` object.
Check the docstring of `exact_treewidth_method` for detailed explaination of other input arguments.
"""
function optimize_exact_treewidth(optimizer::ExactTreewidth{GM}, code::EinCode{L}, size_dict::Dict) where {L,GM}
optimize_exact_treewidth(optimizer, getixsv(code), getiyv(code), size_dict)
end
function optimize_exact_treewidth(optimizer::ExactTreewidth{GM}, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L,TI}) where {L, TI, GM}
if length(ixs) <= 2
return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy))
end
log2_edge_sizes = Dict{L,Float64}()
for (k, v) in size_dict
log2_edge_sizes[k] = log2(v)
end
# complete all open edges as a clique, connected with a dummy tensor
incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)] ∪ [(length(ixs) + 1 => iy)]))

α = optimizer.greedy_config.α
temperature = optimizer.greedy_config.temperature
nrepeat = optimizer.greedy_config.nrepeat
tree = exact_treewidth_method(incidence_list, log2_edge_sizes; α = α, temperature = temperature, nrepeat=nrepeat)

# remove the dummy tensor added for open edges
if isempty(iy)
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
return parse_eincode!(incidence_list, tree, 1:length(ixs))[2]
else
optcode = parse_eincode!(incidence_list, tree, 1:length(ixs) + 1)[2]
return pivot_tree(optcode, length(ixs) + 1)
end
end
27 changes: 27 additions & 0 deletions test/Core.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using OMEinsumContractionOrders, OMEinsum
using OMEinsumContractionOrders: pivot_tree, path_to_tensor

using Test

@testset "tree reformulate" begin
eincode = ein"((ik, jkl), ij), (lm, m) -> "
code = OMEinsum.rawcode(eincode)

size_dict = Dict([c=>(1<<i) for (i,c) in enumerate(['i', 'j', 'k', 'l', 'm'])]...)
tensor_labels = [['i', 'k'], ['j', 'k', 'l'], ['i', 'j'], ['l', 'm'], ['m']]
size_tensors = [log2(prod(size_dict[l] for l in tensor)) for tensor in tensor_labels]
tensors = [rand([size_dict[j] for j in tensor_labels[i]]...) for i in 1:5]
for tensor_index in 1:5

path = path_to_tensor(code, tensor_index)
tensor = reduce((x, y) -> x.args[y], path, init = code)
@test tensor.tensorindex == tensor_index

new_code = pivot_tree(code, tensor_index)
@test contraction_complexity(new_code, size_dict).sc == max(contraction_complexity(code, size_dict).sc, size_tensors[tensor_index])

closed_code = OMEinsumContractionOrders.NestedEinsum([new_code, tensor], OMEinsumContractionOrders.EinCode([OMEinsumContractionOrders.getiyv(new_code), tensor_labels[tensor_index]], Char[]))
new_eincode = OMEinsum.decorate(closed_code)
@test eincode(tensors...) ≈ new_eincode(tensors...)
end
end
16 changes: 16 additions & 0 deletions test/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ using OMEinsum
for i=1:length(results)-1
@test results[i] ≈ results[i+1]
end

small_code = random_regular_eincode(10, 3)
xs = [[randn(2,2) for i=1:15]..., [randn(2) for i=1:10]...]

results = Float64[]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), KaHyParBipartite(sc_target=18), SABipartite(sc_target=18, ntrials=1), ExactTreewidth()]
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(small_code,uniformsize(small_code, 2), optimizer, simplifier)
tc, sc = OMEinsum.timespace_complexity(res, uniformsize(small_code, 2))
push!(results, res(xs...)[])
end
end
for i=1:length(results)-1
@test results[i] ≈ results[i+1]
end
end

@testset "corner case: smaller contraction orders" begin
Expand Down
Loading
Loading