Skip to content

Commit

Permalink
switch to extensions from Requires
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Sep 23, 2023
1 parent 57923b1 commit 2c59909
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 36 deletions.
12 changes: 8 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,29 @@ version = "0.8.2"
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
BetterExp = "7cffe744-45fd-4178-b173-cf893948b8b7"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"

[weakdeps]
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"

[extensions]
KaHyParExt = ["KaHyPar"]

[compat]
AbstractTrees = "0.3, 0.4"
BetterExp = "0.1"
JSON = "0.21"
Requires = "1"
KaHyPar = "0.3"
Suppressor = "0.2"
julia = "1"

[extras]
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"

[targets]
test = ["Test", "Random", "KaHyPar", "Graphs", "TropicalNumbers", "OMEinsum"]
test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar"]
30 changes: 30 additions & 0 deletions ext/KaHyParExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
module KaHyParExt

using OMEinsumContractionOrders: KaHyParBipartite, SparseMatrixCSC, group_sc, induced_subhypergraph, convert2int
import KaHyPar
import OMEinsumContractionOrders: bipartite_sc
using Suppressor: @suppress

function bipartite_sc(bipartiter::KaHyParBipartite, adj::SparseMatrixCSC, vertices, log2_sizes)
n_v = length(vertices)
subgraph, remaining_edges = induced_subhypergraph(adj, vertices)
hypergraph = KaHyPar.HyperGraph(subgraph, ones(n_v), convert2int(log2_sizes[remaining_edges]))
local parts
min_sc = 999999
for imbalance in bipartiter.imbalances
parts = @suppress KaHyPar.partition(hypergraph, 2; imbalance=imbalance, configuration=:edge_cut)
part0 = vertices[parts .== 0]
part1 = vertices[parts .== 1]
sc0, sc1 = group_sc(adj, part0, log2_sizes), group_sc(adj, part1, log2_sizes)
sc = max(sc0, sc1)
min_sc = min(sc, min_sc)
@debug "imbalance $imbalance: sc = $sc, group = ($(length(part0)), $(length(part1)))"
if sc <= bipartiter.sc_target
return part0, part1
end
end
error("fail to find a valid partition for `sc_target = $(bipartiter.sc_target)`, got minimum value `$min_sc` (imbalances = $(bipartiter.imbalances))")
end

@info "`OMEinsumContractionOrders` loads `KaHyParExt` extension successfully."
end
9 changes: 0 additions & 9 deletions src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,8 @@ using SparseArrays
using Base: RefValue
using BetterExp
using Base.Threads
using Suppressor: @suppress
using AbstractTrees

using Requires
function __init__()
@require KaHyPar="2a6221f6-aa48-11e9-3542-2d9e0ef01880" begin
using .KaHyPar
@info "`OMEinsumContractionOrders` loads `KaHyPar` module successfully."
end
end

export CodeOptimizer, CodeSimplifier,
KaHyParBipartite, GreedyMethod, TreeSA, SABipartite,
MinSpaceDiff, MinSpaceOut,
Expand Down
25 changes: 3 additions & 22 deletions src/kahypar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,9 @@ function convert2int(sizes::AbstractVector)
round.(Int, sizes .* 100)
end

function bipartite_sc(bipartiter::KaHyParBipartite, adj::SparseMatrixCSC, vertices, log2_sizes)
n_v = length(vertices)
subgraph, remaining_edges = induced_subhypergraph(adj, vertices)
if !isdefined(@__MODULE__, :KaHyPar)
error("Module `KaHyPar` not found, please type `using KaHyPar` before using the `KaHyParBipartite` optimizer!")
end
hypergraph = KaHyPar.HyperGraph(subgraph, ones(n_v), convert2int(log2_sizes[remaining_edges]))
local parts
min_sc = 999999
for imbalance in bipartiter.imbalances
parts = @suppress KaHyPar.partition(hypergraph, 2; imbalance=imbalance, configuration=:edge_cut)
part0 = vertices[parts .== 0]
part1 = vertices[parts .== 1]
sc0, sc1 = group_sc(adj, part0, log2_sizes), group_sc(adj, part1, log2_sizes)
sc = max(sc0, sc1)
min_sc = min(sc, min_sc)
@debug "imbalance $imbalance: sc = $sc, group = ($(length(part0)), $(length(part1)))"
if sc <= bipartiter.sc_target
return part0, part1
end
end
error("fail to find a valid partition for `sc_target = $(bipartiter.sc_target)`, got minimum value `$min_sc` (imbalances = $(bipartiter.imbalances))")
function bipartite_sc(bipartiter, adj::SparseMatrixCSC, vertices, log2_sizes)
error("""Guess you are trying to use the `KaHyParBipartite` optimizer.
Then you need to add `using KaHyPar` first!""")
end

# the space complexity (external degree of freedoms) if we contract this group
Expand Down
2 changes: 1 addition & 1 deletion src/treesa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function Base.replace!(slicer::Slicer, pair::Pair)
return slicer
end

function Base.push!(slicer, best)
function Base.push!(slicer::Slicer, best)
@assert length(slicer) < slicer.max_size
@assert !haskey(slicer.legs, best)
slicer.legs[best] = slicer.log2_sizes[best] # add best to legs
Expand Down

0 comments on commit 2c59909

Please sign in to comment.