Skip to content

Commit

Permalink
compatable with TreeWidthSolver v0.3 (#47)
Browse files Browse the repository at this point in the history
* compatable with TreeWidthSolver v0.3

* update docs

* support julia v1.8

* update ci
  • Loading branch information
ArrogantGao authored Aug 16, 2024
1 parent 954fc10 commit 0d81c4e
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 42 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.8'
- '1'
os:
- ubuntu-latest
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
LuxorGraphPlot = "0.5.1"
TreeWidthSolver = "0.2"
julia = "1.9"
TreeWidthSolver = "0.3.1"
julia = "1.8"

[extras]
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
12 changes: 5 additions & 7 deletions src/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
"""
exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT}
This function calculates the exact treewidth of a graph using TreeWidthSolver.jl. It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input. The function also accepts optional parameters `α`, `temperature`, and `nrepeat` with default values of 0.0, 0.0, and 1 respectively, which are parameter of the GreedyMethod used in the contraction process.
This function calculates the exact treewidth of a graph using TreeWidthSolver.jl. It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input. The function also accepts optional parameters `α`, `temperature`, and `nrepeat` with default values of 0.0, 0.0, and 1 respectively, which are parameter of the GreedyMethod used in the contraction process as a sub optimizer.
## Arguments
- `incidence_list`: An incidence list representation of the graph.
Expand Down Expand Up @@ -67,10 +67,8 @@ function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_
lg = induced_subgraph(line_graph, vertice_ids)[1]
lg_indicies = indicies[vertice_ids]
lg_weights = weights[vertice_ids]
labeled_graph = LabeledSimpleGraph(lg, lg_indicies, lg_weights)
tree_decomposition = exact_treewidth(labeled_graph)
elimination_order = EliminationOrder(tree_decomposition.tree)
contraction_tree = eo2ct(elimination_order, incidence_list, log2_edge_sizes, α, temperature, nrepeat)
eo = elimination_order(lg, labels = lg_indicies, weights = lg_weights)
contraction_tree = eo2ct(eo, incidence_list, log2_edge_sizes, α, temperature, nrepeat)
push!(contraction_trees, contraction_tree)
end

Expand All @@ -95,8 +93,8 @@ function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) wher
end

# transform elimination order to contraction tree
function eo2ct(elimination_order::EliminationOrder, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes, α::TA, temperature::TT, nrepeat) where {VT, ET, TA, TT}
eo = copy(elimination_order.order)
function eo2ct(elimination_order::Vector{Vector{TL}}, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes, α::TA, temperature::TT, nrepeat) where {TL, VT, ET, TA, TT}
eo = copy(elimination_order)
incidence_list = copy(incidence_list)
contraction_tree_nodes = Vector{Union{VT, ContractionTree}}(collect(keys(incidence_list.v2e)))
tensors_list = Dict{VT, Int}()
Expand Down
76 changes: 48 additions & 28 deletions test/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@ using OMEinsum
xs = [[randn(2,2) for i=1:150]..., [randn(2) for i=1:100]...]

results = Float64[]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), KaHyParBipartite(sc_target=18), SABipartite(sc_target=18, ntrials=1)]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), SABipartite(sc_target=18, ntrials=1)]
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(code,uniformsize(code, 2), optimizer, simplifier)
tc, sc = OMEinsum.timespace_complexity(res, uniformsize(code, 2))
@test sc <= 18
push!(results, res(xs...)[])
end
end
if isdefined(Base, :get_extension)
optimizer = KaHyParBipartite(sc_target=18)
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(code,uniformsize(code, 2), optimizer, simplifier)
Expand All @@ -32,7 +42,16 @@ using OMEinsum
xs = [[randn(2,2) for i=1:15]..., [randn(2) for i=1:10]...]

results = Float64[]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), KaHyParBipartite(sc_target=18), SABipartite(sc_target=18, ntrials=1), ExactTreewidth()]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), SABipartite(sc_target=18, ntrials=1), ExactTreewidth()]
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(small_code,uniformsize(small_code, 2), optimizer, simplifier)
tc, sc = OMEinsum.timespace_complexity(res, uniformsize(small_code, 2))
push!(results, res(xs...)[])
end
end
if isdefined(Base, :get_extension)
optimizer = KaHyParBipartite(sc_target=18)
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(small_code,uniformsize(small_code, 2), optimizer, simplifier)
Expand All @@ -53,7 +72,7 @@ end
@test optimize_code(code, sizes, GreedyMethod()) == sne
@test optimize_code(code, sizes, TreeSA()) == SlicedEinsum(Char[], dne)
@test optimize_code(code, sizes, TreeSA(nslices=2)) == SlicedEinsum(Char[], dne)
@test optimize_code(code, sizes, KaHyParBipartite(sc_target=25)) == dne
isdefined(Base, :get_extension) && (@test optimize_code(code, sizes, KaHyParBipartite(sc_target=25)) == dne)
@test optimize_code(code, sizes, SABipartite(sc_target=25)) == dne
end

Expand All @@ -79,32 +98,33 @@ end
@test 10 * 2^sc2 > pm2 > 2^sc2
end

if isdefined(Base, :get_extension)
@testset "kahypar regression test" begin
code = ein"i->"
optcode = optimize_code(code, Dict('i'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
@test optcode(x) code(x)

@testset "kahypar regression test" begin
code = ein"i->"
optcode = optimize_code(code, Dict('i'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
@test optcode(x) code(x)

code = ein"i,j->"
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
y = randn(4)
@test optcode(x, y) code(x, y)
code = ein"i,j->"
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
y = randn(4)
@test optcode(x, y) code(x, y)

code = ein"ij,jk,kl->ijl"
println(code)
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4, 'k'=>4, 'l'=>4), KaHyParBipartite(; sc_target=4, max_group_size=2))
println(optcode)
@test optcode isa NestedEinsum
a, b, c = [rand(4,4) for i=1:4]
@test optcode(a, b, c) code(a, b, c)
code = ein"ij,jk,kl->ijl"
println(code)
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4, 'k'=>4, 'l'=>4), KaHyParBipartite(; sc_target=4, max_group_size=2))
println(optcode)
@test optcode isa NestedEinsum
a, b, c = [rand(4,4) for i=1:4]
@test optcode(a, b, c) code(a, b, c)

code = ein"ij,jk,kl->ijl"
optcode = optimize_code(code, Dict('i'=>3, 'j'=>3, 'k'=>3, 'l'=>3), KaHyParBipartite(; sc_target=4, max_group_size=2))
@test optcode isa NestedEinsum
a, b, c = [rand(3,3) for i=1:4]
@test optcode(a, b, c) code(a, b, c)
code = ein"ij,jk,kl->ijl"
optcode = optimize_code(code, Dict('i'=>3, 'j'=>3, 'k'=>3, 'l'=>3), KaHyParBipartite(; sc_target=4, max_group_size=2))
@test optcode isa NestedEinsum
a, b, c = [rand(3,3) for i=1:4]
@test optcode(a, b, c) code(a, b, c)
end
end
12 changes: 8 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ end
include("sa.jl")
end

@testset "kahypar" begin
include("kahypar.jl")
if isdefined(Base, :get_extension)
@testset "kahypar" begin
include("kahypar.jl")
end
end

@testset "treesa" begin
Expand All @@ -38,6 +40,8 @@ end
end

# testing the extension `LuxorTensorPlot` for visualization
@testset "visualization" begin
include("visualization.jl")
if isdefined(Base, :get_extension)
@testset "visualization" begin
include("visualization.jl")
end
end
2 changes: 1 addition & 1 deletion test/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Test, Random
cc = contraction_complexity(optcode, size_dict)
# test flop
@test cc.tc log2(flop(optcode, size_dict))
@test 16 <= cc.tc <= log2(exp2(10)+exp2(16)+exp2(15)+exp2(9))
@test (16 <= cc.tc <= log2(exp2(10)+exp2(16)+exp2(15)+exp2(9))) | (cc.tc log2(exp2(10)+exp2(16)+exp2(15)+exp2(9)))
@test cc.sc == 11
@test decorate(eincode)(tensors...) decorate(optcode)(tensors...)

Expand Down

0 comments on commit 0d81c4e

Please sign in to comment.