Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compatable with TreeWidthSolver v0.3 #47

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.8'
- '1'
os:
- ubuntu-latest
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
LuxorGraphPlot = "0.5.1"
TreeWidthSolver = "0.2"
julia = "1.9"
TreeWidthSolver = "0.3.1"
julia = "1.8"

[extras]
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
12 changes: 5 additions & 7 deletions src/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
"""
exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; α::TA = 0.0, temperature::TT = 0.0, nrepeat=1) where {VT,ET,TA,TT}

This function calculates the exact treewidth of a graph using TreeWidthSolver.jl. It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input. The function also accepts optional parameters `α`, `temperature`, and `nrepeat` with default values of 0.0, 0.0, and 1 respectively, which are parameter of the GreedyMethod used in the contraction process.
This function calculates the exact treewidth of a graph using TreeWidthSolver.jl. It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input. The function also accepts optional parameters `α`, `temperature`, and `nrepeat` with default values of 0.0, 0.0, and 1 respectively, which are parameter of the GreedyMethod used in the contraction process as a sub optimizer.

## Arguments
- `incidence_list`: An incidence list representation of the graph.
Expand Down Expand Up @@ -67,10 +67,8 @@ function exact_treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_
lg = induced_subgraph(line_graph, vertice_ids)[1]
lg_indicies = indicies[vertice_ids]
lg_weights = weights[vertice_ids]
labeled_graph = LabeledSimpleGraph(lg, lg_indicies, lg_weights)
tree_decomposition = exact_treewidth(labeled_graph)
elimination_order = EliminationOrder(tree_decomposition.tree)
contraction_tree = eo2ct(elimination_order, incidence_list, log2_edge_sizes, α, temperature, nrepeat)
eo = elimination_order(lg, labels = lg_indicies, weights = lg_weights)
contraction_tree = eo2ct(eo, incidence_list, log2_edge_sizes, α, temperature, nrepeat)
push!(contraction_trees, contraction_tree)
end

Expand All @@ -95,8 +93,8 @@ function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) wher
end

# transform elimination order to contraction tree
function eo2ct(elimination_order::EliminationOrder, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes, α::TA, temperature::TT, nrepeat) where {VT, ET, TA, TT}
eo = copy(elimination_order.order)
function eo2ct(elimination_order::Vector{Vector{TL}}, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes, α::TA, temperature::TT, nrepeat) where {TL, VT, ET, TA, TT}
eo = copy(elimination_order)
incidence_list = copy(incidence_list)
contraction_tree_nodes = Vector{Union{VT, ContractionTree}}(collect(keys(incidence_list.v2e)))
tensors_list = Dict{VT, Int}()
Expand Down
76 changes: 48 additions & 28 deletions test/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@ using OMEinsum
xs = [[randn(2,2) for i=1:150]..., [randn(2) for i=1:100]...]

results = Float64[]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), KaHyParBipartite(sc_target=18), SABipartite(sc_target=18, ntrials=1)]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), SABipartite(sc_target=18, ntrials=1)]
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(code,uniformsize(code, 2), optimizer, simplifier)
tc, sc = OMEinsum.timespace_complexity(res, uniformsize(code, 2))
@test sc <= 18
push!(results, res(xs...)[])
end
end
if isdefined(Base, :get_extension)
optimizer = KaHyParBipartite(sc_target=18)
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(code,uniformsize(code, 2), optimizer, simplifier)
Expand All @@ -32,7 +42,16 @@ using OMEinsum
xs = [[randn(2,2) for i=1:15]..., [randn(2) for i=1:10]...]

results = Float64[]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), KaHyParBipartite(sc_target=18), SABipartite(sc_target=18, ntrials=1), ExactTreewidth()]
for optimizer in [TreeSA(ntrials=1), TreeSA(ntrials=1, nslices=5), GreedyMethod(), SABipartite(sc_target=18, ntrials=1), ExactTreewidth()]
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(small_code,uniformsize(small_code, 2), optimizer, simplifier)
tc, sc = OMEinsum.timespace_complexity(res, uniformsize(small_code, 2))
push!(results, res(xs...)[])
end
end
if isdefined(Base, :get_extension)
optimizer = KaHyParBipartite(sc_target=18)
for simplifier in (nothing, MergeVectors(), MergeGreedy())
@info "optimizer = $(optimizer), simplifier = $(simplifier)"
res = optimize_code(small_code,uniformsize(small_code, 2), optimizer, simplifier)
Expand All @@ -53,7 +72,7 @@ end
@test optimize_code(code, sizes, GreedyMethod()) == sne
@test optimize_code(code, sizes, TreeSA()) == SlicedEinsum(Char[], dne)
@test optimize_code(code, sizes, TreeSA(nslices=2)) == SlicedEinsum(Char[], dne)
@test optimize_code(code, sizes, KaHyParBipartite(sc_target=25)) == dne
isdefined(Base, :get_extension) && (@test optimize_code(code, sizes, KaHyParBipartite(sc_target=25)) == dne)
@test optimize_code(code, sizes, SABipartite(sc_target=25)) == dne
end

Expand All @@ -79,32 +98,33 @@ end
@test 10 * 2^sc2 > pm2 > 2^sc2
end

if isdefined(Base, :get_extension)
@testset "kahypar regression test" begin
code = ein"i->"
optcode = optimize_code(code, Dict('i'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
@test optcode(x) ≈ code(x)

@testset "kahypar regression test" begin
code = ein"i->"
optcode = optimize_code(code, Dict('i'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
@test optcode(x) ≈ code(x)

code = ein"i,j->"
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
y = randn(4)
@test optcode(x, y) ≈ code(x, y)
code = ein"i,j->"
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4), KaHyParBipartite(; sc_target=10, max_group_size=10))
@test optcode isa NestedEinsum
x = randn(4)
y = randn(4)
@test optcode(x, y) ≈ code(x, y)

code = ein"ij,jk,kl->ijl"
println(code)
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4, 'k'=>4, 'l'=>4), KaHyParBipartite(; sc_target=4, max_group_size=2))
println(optcode)
@test optcode isa NestedEinsum
a, b, c = [rand(4,4) for i=1:4]
@test optcode(a, b, c) ≈ code(a, b, c)
code = ein"ij,jk,kl->ijl"
println(code)
optcode = optimize_code(code, Dict('i'=>4, 'j'=>4, 'k'=>4, 'l'=>4), KaHyParBipartite(; sc_target=4, max_group_size=2))
println(optcode)
@test optcode isa NestedEinsum
a, b, c = [rand(4,4) for i=1:4]
@test optcode(a, b, c) ≈ code(a, b, c)

code = ein"ij,jk,kl->ijl"
optcode = optimize_code(code, Dict('i'=>3, 'j'=>3, 'k'=>3, 'l'=>3), KaHyParBipartite(; sc_target=4, max_group_size=2))
@test optcode isa NestedEinsum
a, b, c = [rand(3,3) for i=1:4]
@test optcode(a, b, c) ≈ code(a, b, c)
code = ein"ij,jk,kl->ijl"
optcode = optimize_code(code, Dict('i'=>3, 'j'=>3, 'k'=>3, 'l'=>3), KaHyParBipartite(; sc_target=4, max_group_size=2))
@test optcode isa NestedEinsum
a, b, c = [rand(3,3) for i=1:4]
@test optcode(a, b, c) ≈ code(a, b, c)
end
end
12 changes: 8 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ end
include("sa.jl")
end

@testset "kahypar" begin
include("kahypar.jl")
if isdefined(Base, :get_extension)
@testset "kahypar" begin
include("kahypar.jl")
end
end

@testset "treesa" begin
Expand All @@ -38,6 +40,8 @@ end
end

# testing the extension `LuxorTensorPlot` for visualization
@testset "visualization" begin
include("visualization.jl")
if isdefined(Base, :get_extension)
@testset "visualization" begin
include("visualization.jl")
end
end
2 changes: 1 addition & 1 deletion test/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Test, Random
cc = contraction_complexity(optcode, size_dict)
# test flop
@test cc.tc ≈ log2(flop(optcode, size_dict))
@test 16 <= cc.tc <= log2(exp2(10)+exp2(16)+exp2(15)+exp2(9))
@test (16 <= cc.tc <= log2(exp2(10)+exp2(16)+exp2(15)+exp2(9))) | (cc.tc ≈ log2(exp2(10)+exp2(16)+exp2(15)+exp2(9)))
@test cc.sc == 11
@test decorate(eincode)(tensors...) ≈ decorate(optcode)(tensors...)

Expand Down
Loading