Skip to content

Commit

Permalink
change tree_reformulate as pivot_tree
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrogantGao committed Aug 1, 2024
1 parent b05adac commit a72fbaf
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ end

# reformulate the nested einsum, removing a given tensor without change the space complexity
# consider only binary contraction tree with no openedges
function tree_reformulate(code::NestedEinsum, removed_tensor_id::Int)
function pivot_tree(code::NestedEinsum, removed_tensor_id::Int)

try @assert is_binary_tree(code) catch
error("The contraction tree is not binary")
Expand All @@ -108,13 +108,13 @@ function tree_reformulate(code::NestedEinsum, removed_tensor_id::Int)
left_code = code.args[left]
right_code = NestedEinsum([code.args[right].args...], EinCode(getixsv(code.args[right].eins), getixsv(code.eins)[left]))
end
tree = _tree_reformulate!(left_code, right_code, path)
tree = _pivot_tree!(left_code, right_code, path)

return tree
end


function _tree_reformulate!(left_code::NestedEinsum{LT}, right_code::NestedEinsum{LT}, path::Vector{Int}) where{LT}
function _pivot_tree!(left_code::NestedEinsum{LT}, right_code::NestedEinsum{LT}, path::Vector{Int}) where{LT}
if !isleaf(right_code)
right = popfirst!(path)
left = right == 1 ? 2 : 1
Expand All @@ -123,13 +123,13 @@ function _tree_reformulate!(left_code::NestedEinsum{LT}, right_code::NestedEinsu
# reformulated: left: a -> b, right: b
new_eins = EinCode([getiyv(right_code.eins)], getixsv(right_code.eins)[1])
left_code = NestedEinsum([left_code], new_eins)
left_code = _tree_reformulate!(left_code, right_code.args[1], path)
left_code = _pivot_tree!(left_code, right_code.args[1], path)
elseif length(right_code.args) == 2
# origin: left: a, right: b, c -> a
# reformulated: left: a, b -> c, right: c
new_eins = EinCode([getiyv(right_code.eins), getixsv(right_code.eins)[left]], getixsv(right_code.eins)[right])
left_code = NestedEinsum([left_code, right_code.args[left]], new_eins)
left_code = _tree_reformulate!(left_code, right_code.args[right], path)
left_code = _pivot_tree!(left_code, right_code.args[right], path)
else
error("The contraction tree is not binary")
end
Expand Down
2 changes: 1 addition & 1 deletion src/kahypar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ function recursive_bipartite_optimize(bipartiter, code::EinCode, size_dict)
if isempty(iy)
return optcode
else
return tree_reformulate(optcode, length(ixs) + 1)
return pivot_tree(optcode, length(ixs) + 1)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,6 @@ function optimize_exact_treewidth(optimizer::ExactTreewidth{GM}, ixs::AbstractVe
return parse_eincode!(incidence_list, tree, 1:length(ixs))[2]
else
optcode = parse_eincode!(incidence_list, tree, 1:length(ixs) + 1)[2]
return tree_reformulate(optcode, length(ixs) + 1)
return pivot_tree(optcode, length(ixs) + 1)
end
end
4 changes: 2 additions & 2 deletions test/Core.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using OMEinsumContractionOrders, OMEinsum
using OMEinsumContractionOrders: tree_reformulate, path_to_tensor
using OMEinsumContractionOrders: pivot_tree, path_to_tensor

using Test

Expand All @@ -17,7 +17,7 @@ using Test
tensor = reduce((x, y) -> x.args[y], path, init = code)
@test tensor.tensorindex == tensor_index

new_code = tree_reformulate(code, tensor_index)
new_code = pivot_tree(code, tensor_index)
@test contraction_complexity(new_code, size_dict).sc == max(contraction_complexity(code, size_dict).sc, size_tensors[tensor_index])

closed_code = OMEinsumContractionOrders.NestedEinsum([new_code, tensor], OMEinsumContractionOrders.EinCode([OMEinsumContractionOrders.getiyv(new_code), tensor_labels[tensor_index]], Char[]))
Expand Down

0 comments on commit a72fbaf

Please sign in to comment.