From 4d946bca445027dfecf10e667604d07bd2475a36 Mon Sep 17 00:00:00 2001 From: ArrogantGao Date: Fri, 2 Aug 2024 19:54:20 +0800 Subject: [PATCH 1/8] add LuxorTensorPlot as extension --- Project.toml | 6 +- ext/LuxorTensorPlot.jl | 5 + ext/LuxorTensorPlot/src/LuxorTensorPlot.jl | 14 +++ ext/LuxorTensorPlot/src/hypergraph.jl | 71 +++++++++++ ext/LuxorTensorPlot/src/viz_contraction.jl | 138 +++++++++++++++++++++ ext/LuxorTensorPlot/src/viz_eins.jl | 80 ++++++++++++ src/OMEinsumContractionOrders.jl | 6 + src/visualization.jl | 15 +++ test/runtests.jl | 5 + test/visualization.jl | 63 ++++++++++ 10 files changed, 402 insertions(+), 1 deletion(-) create mode 100644 ext/LuxorTensorPlot.jl create mode 100644 ext/LuxorTensorPlot/src/LuxorTensorPlot.jl create mode 100644 ext/LuxorTensorPlot/src/hypergraph.jl create mode 100644 ext/LuxorTensorPlot/src/viz_contraction.jl create mode 100644 ext/LuxorTensorPlot/src/viz_eins.jl create mode 100644 src/visualization.jl create mode 100644 test/visualization.jl diff --git a/Project.toml b/Project.toml index b6fde95..7191e97 100644 --- a/Project.toml +++ b/Project.toml @@ -12,9 +12,11 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" [weakdeps] KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" +LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc" [extensions] KaHyParExt = ["KaHyPar"] +LuxorTensorPlot = ["LuxorGraphPlot"] [compat] AbstractTrees = "0.3, 0.4" @@ -22,6 +24,7 @@ JSON = "0.21" KaHyPar = "0.3" StatsBase = "0.34" Suppressor = "0.2" +LuxorGraphPlot = "0.5.1" julia = "1.9" [extras] @@ -30,6 +33,7 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334" +LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc" [targets] -test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar"] +test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar", "LuxorGraphPlot"] diff --git a/ext/LuxorTensorPlot.jl b/ext/LuxorTensorPlot.jl new file mode 100644 index 0000000..0d46624 --- /dev/null +++ b/ext/LuxorTensorPlot.jl @@ -0,0 +1,5 @@ +module LuxorTensorPlot + +include("LuxorTensorPlot/src/LuxorTensorPlot.jl") + +end \ No newline at end of file diff --git a/ext/LuxorTensorPlot/src/LuxorTensorPlot.jl b/ext/LuxorTensorPlot/src/LuxorTensorPlot.jl new file mode 100644 index 0000000..09961e1 --- /dev/null +++ b/ext/LuxorTensorPlot/src/LuxorTensorPlot.jl @@ -0,0 +1,14 @@ +using OMEinsumContractionOrders, LuxorGraphPlot + +using OMEinsumContractionOrders.SparseArrays +using LuxorGraphPlot.Graphs +using LuxorGraphPlot.Luxor +using LuxorGraphPlot.Luxor.FFMPEG + +using OMEinsumContractionOrders: AbstractEinsum, NestedEinsum, SlicedEinsum +using OMEinsumContractionOrders: getixsv, getiyv +using OMEinsumContractionOrders: ein2hypergraph, ein2elimination + +include("hypergraph.jl") +include("viz_eins.jl") +include("viz_contraction.jl") \ No newline at end of file diff --git a/ext/LuxorTensorPlot/src/hypergraph.jl b/ext/LuxorTensorPlot/src/hypergraph.jl new file mode 100644 index 0000000..3cfe87d --- /dev/null +++ b/ext/LuxorTensorPlot/src/hypergraph.jl @@ -0,0 +1,71 @@ +struct LabeledHyperGraph{TS, TV, TE} + adjacency_matrix::SparseMatrixCSC{TS} + vertex_labels::Vector{TV} + edge_labels::Vector{TE} + open_edges::Vector{TE} + + function LabeledHyperGraph(adjacency_matrix::SparseMatrixCSC{TS}; vl::Vector{TV} = [1:size(adjacency_matrix, 1)...], el::Vector{TE} = [1:size(adjacency_matrix, 2)...], oe::Vector = []) where{TS, TV, TE} + if size(adjacency_matrix, 1) != length(vl) + throw(ArgumentError("Number of vertices does not match number of vertex labels")) + end + if size(adjacency_matrix, 2) != length(el) + throw(ArgumentError("Number of edges does not match number of edge labels")) + end + if !all(oei in el for oei in oe) + throw(ArgumentError("Open edges must be in edge labels")) + end + if isempty(oe) + oe = Vector{TE}() + end + new{TS, TV, TE}(adjacency_matrix, vl, el, oe) + end +end + +Base.show(io::IO, g::LabeledHyperGraph{TS, TV, TE}) where{TS,TV,TE} = print(io, "LabeledHyperGraph{$TS, $TV, $TE} \n adjacency_mat: $(g.adjacency_matrix) \n vertex: $(g.vertex_labels) \n edges: $(g.edge_labels)) \n open_edges: $(g.open_edges)") + +Base.:(==)(a::LabeledHyperGraph, b::LabeledHyperGraph) = a.adjacency_matrix == b.adjacency_matrix && a.vertex_labels == b.vertex_labels && a.edge_labels == b.edge_labels && a.open_edges == b.open_edges + +struct TensorNetworkGraph{TT, TI} + graph::SimpleGraph + tensors_labels::Dict{Int, TT} + indices_labels::Dict{Int, TI} + open_indices::Vector{TI} + + function TensorNetworkGraph(graph::SimpleGraph; tl::Dict{Int, TT} = Dict{Int, Int}(), il::Dict{Int, TI} = Dict{Int, Int}(), oi::Vector = []) where{TT, TI} + if length(tl) + length(il) != nv(graph) + throw(ArgumentError("Number of tensors + indices does not match number of vertices")) + end + if !all(oii in values(il) for oii in oi) + throw(ArgumentError("Open indices must be in indices")) + end + if isempty(oi) + oi = Vector{TI}() + end + new{TT, TI}(graph, tl, il, oi) + end +end + +Base.show(io::IO, g::TensorNetworkGraph{TT, TI}) where{TT, TI} = print(io, "TensorNetworkGraph{$TT, $TI} \n graph: {$(nv(g.graph)), $(ne(g.graph))} \n tensors: $(g.tensors_labels) \n indices: $(g.indices_labels)) \n open_indices: $(g.open_indices)") + +# convert the labeled hypergraph to a tensor network graph, where vertices and edges of the hypergraph are mapped as the vertices of the tensor network graph, and the open edges are recorded. +function TensorNetworkGraph(lhg::LabeledHyperGraph{TS, TV, TE}) where{TS, TV, TE} + graph = SimpleGraph(length(lhg.vertex_labels) + length(lhg.edge_labels)) + tensors_labels = Dict{Int, TV}() + indices_labels = Dict{Int, TE}() + + lv = length(lhg.vertex_labels) + for i in 1:length(lhg.vertex_labels) + tensors_labels[i] = lhg.vertex_labels[i] + end + for i in 1:length(lhg.edge_labels) + indices_labels[i + lv] = lhg.edge_labels[i] + end + + for i in 1:size(lhg.adjacency_matrix, 1) + for j in findall(!iszero, lhg.adjacency_matrix[i, :]) + add_edge!(graph, i, j + lv) + end + end + + TensorNetworkGraph(graph, tl=tensors_labels, il=indices_labels, oi=lhg.open_edges) +end \ No newline at end of file diff --git a/ext/LuxorTensorPlot/src/viz_contraction.jl b/ext/LuxorTensorPlot/src/viz_contraction.jl new file mode 100644 index 0000000..ffb288b --- /dev/null +++ b/ext/LuxorTensorPlot/src/viz_contraction.jl @@ -0,0 +1,138 @@ +function OMEinsumContractionOrders.ein2elimination(ein::NestedEinsum{T}) where{T} + elimination_order = Vector{T}() + _ein2elimination!(ein, elimination_order) + return elimination_order +end + +function OMEinsumContractionOrders.ein2elimination(ein::SlicedEinsum{T, NestedEinsum{T}}) where{T} + elimination_order = Vector{T}() + _ein2elimination!(ein.eins, elimination_order) + # the slicing indices are eliminated at the end + return vcat(elimination_order, ein.slicing) +end + +function _ein2elimination!(ein::NestedEinsum{T}, elimination_order::Vector{T}) where{T} + if ein.tensorindex == -1 + for arg in ein.args + _ein2elimination!(arg, elimination_order) + end + iy = unique(vcat(getiyv(ein.eins)...)) + for ix in unique(vcat(getixsv(ein.eins)...)) + if !(ix in iy) && !(ix in elimination_order) + push!(elimination_order, ix) + end + end + end + return elimination_order +end + +function elimination_frame(GViz, tng::TensorNetworkGraph{TG, TL}, elimination_order::Vector{TL}, i::Int; filename = nothing, color = (0.5, 0.5, 0.5, 0.5)) where{TG, TL} + GViz2 = deepcopy(GViz) + for j in 1:i + id = _get_key(tng.indices_labels, elimination_order[j]) + GViz2.vertex_colors[id] = color + end + return show_graph(GViz2, filename = filename) +end + +function OMEinsumContractionOrders.viz_contraction(ein::T, args...; kwargs...) where{T <: AbstractEinsum} + throw(ArgumentError("Only NestedEinsum and SlicedEinsum{T, NestedEinsum{T}} have contraction order")) +end + +""" + viz_contraction(ein::ET; locs=StressLayout(), framerate=30, filename="contraction", pathname=".", create_gif=false, create_video=true, color=(0.5, 0.5, 0.5, 0.5), show_progress=false) where {ET <: Union{NestedEinsum, SlicedEinsum}} + +Visualize the contraction process of a tensor network. + +# Arguments +- `ein::ET`: The tensor network to visualize. +- `locs`: The layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`. +- `framerate`: The frame rate of the animation. Default is 30. +- `filename`: The base name of the output files. Default is "contraction". +- `pathname`: The directory path to save the output files. Default is the current directory. +- `create_gif`: Whether to create a GIF animation. Default is `false`. +- `create_video`: Whether to create a video. Default is `true`. +- `color`: The color of the contraction lines. Default is `(0.5, 0.5, 0.5, 0.5)`. +- `show_progress`: Whether to show progress information. Default is `false`. + +# Returns +- If `create_gif` is `true`, returns the path to the generated GIF animation. +- If `create_video` is `true`, returns the path to the generated video. +""" +function OMEinsumContractionOrders.viz_contraction( + ein::ET; + locs=StressLayout(), + framerate = 30, + filename = "contraction", + pathname = ".", + create_gif = false, + create_video = true, + color = (0.5, 0.5, 0.5, 0.5), + show_progress::Bool = false + ) where{ET <: Union{NestedEinsum, SlicedEinsum}} + + elimination_order = ein2elimination(ein) + tng = TensorNetworkGraph(ein2hypergraph(ein)) + GViz = GraphViz(tng, locs) + + tempdirectory = mktempdir() + # @info("Frames for animation \"$(filename)\" are being stored in directory: \n\t $(tempdirectory)") + + filecounter = 1 + le = length(elimination_order) + @info "Generating frames, $le frames in total" + for i in 0:le + if show_progress + @info "Frame $(i) of $le" + end + fig_name = "$(tempdirectory)/$(lpad(filecounter, 10, "0")).png" + elimination_frame(GViz, tng, elimination_order, i; filename = fig_name, color = color) + filecounter += 1 + end + + if create_gif + Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $(tempdirectory)/$(filename).gif`) + + if !isempty(pathname) + if !isdir(pathname) + @error "$pathname is not a directory." + end + fig_path = joinpath(pathname, "$filename.gif") + mv("$(tempdirectory)/$(filename).gif", fig_path, force = true) + @info("GIF is: $fig_path") + giffn = fig_path + else + @info("GIF is: $(tempdirectory)/$(filename).gif") + giffn = tempdirectory * "/" * filename * ".gif" + end + + return giffn + elseif create_video + movieformat = ".mp4" + + if !isempty(pathname) + if !isdir(pathname) + @error "$pathname is not a directory." + end + pathname = joinpath(pathname, "$(filename)$(movieformat)") + else + pathname = joinpath("$(tempdirectory)", "$(filename)$(movieformat)") + end + + @info "Creating video at: $pathname" + FFMPEG.ffmpeg_exe(` + -loglevel panic + -r $(framerate) + -f image2 + -i $(tempdirectory)/%10d.png + -c:v libx264 + -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" + -r $(framerate) + -pix_fmt yuv420p + -y $(pathname)`) + + return pathname + else + return tempdirectory + end +end \ No newline at end of file diff --git a/ext/LuxorTensorPlot/src/viz_eins.jl b/ext/LuxorTensorPlot/src/viz_eins.jl new file mode 100644 index 0000000..58b6859 --- /dev/null +++ b/ext/LuxorTensorPlot/src/viz_eins.jl @@ -0,0 +1,80 @@ +function LuxorGraphPlot.GraphViz(tng::TensorNetworkGraph, locs=StressLayout(); highlight::Vector=[], highlight_color = (0.0, 0.0, 255.0, 0.5), kwargs...) + + white = (255.0, 255.0, 255.0, 0.8) + black = (0.0, 0.0, 0.0, 1.0) + r = (255.0, 0.0, 0.0, 0.8) + g = (0.0, 255.0, 0.0, 0.8) + + colors = Vector{typeof(r)}() + text = Vector{String}() + sizes = Vector{Float64}() + + for i in 1:nv(tng.graph) + if i in keys(tng.tensors_labels) + push!(colors, white) + push!(text, string(tng.tensors_labels[i])) + push!(sizes, 20.0) + else + push!(colors, r) + push!(text, string(tng.indices_labels[i])) + push!(sizes, 10.0) + end + end + + for oi in tng.open_indices + id = _get_key(tng.indices_labels, oi) + colors[id] = g + end + + for hl in highlight + id = _get_key(tng.indices_labels, hl) + colors[id] = highlight_color + end + + return GraphViz(tng.graph, locs, texts = text, vertex_colors = colors, vertex_sizes = sizes, kwargs...) +end + +function _get_key(dict::Dict, value) + for (key, val) in dict + if val == value + return key + end + end + @error "Value not found in dictionary" +end + +function OMEinsumContractionOrders.ein2hypergraph(ec::T) where{T <: AbstractEinsum} + ixs = getixsv(ec) + iy = getiyv(ec) + + edges = unique!([Iterators.flatten(ixs)...]) + open_edges = [iy[i] for i in 1:length(iy) if iy[i] in edges] + + rows = Int[] + cols = Int[] + for (i,ix) in enumerate(ixs) + push!(rows, map(x->i, ix)...) + push!(cols, map(x->findfirst(==(x), edges), ix)...) + end + adj = sparse(rows, cols, ones(Int, length(rows))) + + return LabeledHyperGraph(adj, el = edges, oe = open_edges) +end + +""" + viz_eins(ec::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...) + +Visualizes an `AbstractEinsum` object by creating a tensor network graph and rendering it using GraphViz. + +## Arguments +- `ec::AbstractEinsum`: The `AbstractEinsum` object to visualize. +- `locs=StressLayout()`: The layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`. +- `filename = nothing`: The name of the file to save the visualization to. If `nothing`, the visualization will be displayed on the screen instead of saving to a file. +- `kwargs...`: Additional keyword arguments to be passed to the `GraphViz` constructor. + +""" +function OMEinsumContractionOrders.viz_eins(ec::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...) + tng = TensorNetworkGraph(ein2hypergraph(ec)) + gviz = GraphViz(tng, locs; kwargs...) + return show_graph(gviz, filename = filename) +end \ No newline at end of file diff --git a/src/OMEinsumContractionOrders.jl b/src/OMEinsumContractionOrders.jl index 7ac2298..19a6609 100644 --- a/src/OMEinsumContractionOrders.jl +++ b/src/OMEinsumContractionOrders.jl @@ -17,6 +17,9 @@ export CodeOptimizer, CodeSimplifier, label_elimination_order # writejson, readjson are not exported to avoid namespace conflict +# visiualization tools provided by extension `LuxorTensorPlot` +export viz_eins, viz_contraction + include("Core.jl") include("utils.jl") @@ -41,6 +44,9 @@ include("interfaces.jl") # saveload include("json.jl") +# extension for visiualization +include("visualization.jl") + @deprecate timespacereadwrite_complexity(code, size_dict::Dict) (contraction_complexity(code, size_dict)...,) @deprecate timespace_complexity(code, size_dict::Dict) (contraction_complexity(code, size_dict)...,)[1:2] diff --git a/src/visualization.jl b/src/visualization.jl new file mode 100644 index 0000000..6f0e5b7 --- /dev/null +++ b/src/visualization.jl @@ -0,0 +1,15 @@ +function ein2hypergraph(args...; kwargs...) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) +end + +function ein2elimination(args...; kwargs...) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) +end + +function viz_eins(args...; kwargs...) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) +end + +function viz_contraction(args...; kwargs...) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6b2a6e5..54acb8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,3 +28,8 @@ end @testset "json" begin include("json.jl") end + +# testing the extension `LuxorTensorPlot` for visualization +@testset "visualization" begin + include("visualization.jl") +end \ No newline at end of file diff --git a/test/visualization.jl b/test/visualization.jl new file mode 100644 index 0000000..590a036 --- /dev/null +++ b/test/visualization.jl @@ -0,0 +1,63 @@ +using LuxorGraphPlot +using LuxorGraphPlot.Luxor +using OMEinsumContractionOrders: ein2hypergraph, ein2elimination + +@testset "eincode to hypergraph" begin + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a']) + g1 = ein2hypergraph(eincode) + + nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) + g2 = ein2hypergraph(nested_code) + + sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA(nslices = 1)) + g3 = ein2hypergraph(sliced_code) + + @test g1 == g2 == g3 + @test size(g1.adjacency_matrix, 1) == 5 + @test size(g1.adjacency_matrix, 2) == 6 +end + +@testset "eincode to elimination order" begin + eincode = OMEinsum.rawcode(ein"((ij, jk), kl), lm -> im") + elimination_order = ein2elimination(eincode) + @test elimination_order == ['j', 'k', 'l'] +end + +@testset "visualize eincode" begin + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) + t = viz_eins(eincode) + @test typeof(t) == Luxor.Drawing + + nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) + t = viz_eins(nested_code) + @test typeof(t) == Luxor.Drawing + + sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA()) + t = viz_eins(sliced_code) + @test typeof(t) == Luxor.Drawing + + open_eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a']) + t = viz_eins(open_eincode) + @test typeof(t) == Luxor.Drawing +end + +@testset "visualize contraction" begin + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) + nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) + t_mp4 = viz_contraction(nested_code, pathname = "") + @test typeof(t_mp4) == String + t_gif = viz_contraction(nested_code, pathname = "", create_gif = true) + @test typeof(t_gif) == String + + sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA()) + t_mp4 = viz_contraction(sliced_code, pathname = "") + @test typeof(t_mp4) == String + t_gif = viz_contraction(sliced_code, pathname = "", create_gif = true) + @test typeof(t_gif) == String + + sliced_code2 = optimize_code(eincode, uniformsize(eincode, 2), TreeSA(nslices = 1)) + t_mp4 = viz_contraction(sliced_code2, pathname = "") + @test typeof(t_mp4) == String + t_gif = viz_contraction(sliced_code2, pathname = "", create_gif = true) + @test typeof(t_gif) == String +end \ No newline at end of file From 3c3f608b0898d7a4c2458009fb956abcf099580d Mon Sep 17 00:00:00 2001 From: ArrogantGao Date: Fri, 2 Aug 2024 20:18:40 +0800 Subject: [PATCH 2/8] update readme for the viz toll --- README.md | 46 +++++++++++++++++++++ examples/eins.png | Bin 0 -> 28984 bytes examples/visualization.jl | 8 ++++ ext/LuxorTensorPlot/src/viz_contraction.jl | 4 +- 4 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 examples/eins.png create mode 100644 examples/visualization.jl diff --git a/README.md b/README.md index 1b51d51..5fb9f7f 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,52 @@ SlicedEinsum{Char, NestedEinsum{DynamicEinCode{Char}}}(Char[], ki, ki -> ) ``` +## Extensions + +### LuxorTensorPlot + +`LuxorTensorPlot` is an extension of the `OMEinsumContractionOrders` package that provides a visualization of the contraction order. It is designed to work with the `OMEinsumContractionOrders` package. To use `LuxorTensorPlot`, please follow these steps: +```julia +pkg> add OMEinsumContractionOrders, LuxorGraphPlot + +julia> using OMEinsumContractionOrders, LuxorGraphPlot +``` +and then the extension will be loaded automatically. + +The extension provides the following to function, `viz_eins` and `viz_contraction`, where the former will plot the tensor network as a graph, and the latter will generate a video or gif of the contraction process. +Here is an example: +```julia +julia> using OMEinsumContractionOrders, LuxorGraphPlot + +julia> eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a']) +ab, acd, bcef, e, df -> a + +julia> viz_eins(eincode, filename = "eins.png") + +julia> nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) +ab, ab -> a +├─ ab +└─ acf, bcf -> ab + ├─ acd, df -> acf + │ ├─ acd + │ └─ df + └─ bcef, e -> bcf + ├─ bcef + └─ e + + +julia> viz_contraction(nested_eins) +[ Info: Generating frames, 5 frames in total +[ Info: Creating video at: ./contraction.mp4 +"./contraction.mp4" +``` + +The resulting image and video will be saved in the current working directory, and the image is shown below: +
+ Image +
+The large white nodes represent the tensors, and the small colored nodes represent the indices, red for closed indices and green for open indices. + ## References If you find this package useful in your research, please cite the *relevant* papers in [CITATION.bib](CITATION.bib). diff --git a/examples/eins.png b/examples/eins.png new file mode 100644 index 0000000000000000000000000000000000000000..a02363f16e0c5a6c8445a9ac3ff49d0af7fe14e8 GIT binary patch literal 28984 zcmYhj2RN4f|2}@%QdTKNMiiBdLS*k%i56uhA!L-5c}JooN>WBviISO-Eh-{ABeG|N zLdg1`_w)IGkK^ZaJje5S?)$p$`+dFNuW`Q4*ZI2c>gj4SGH@^u1i`4SrFNPiC@JvI z4|+wJ7Q?3Z>b{k00;dyZ3QaR6#+(;!K}kXV#^mPodkkRn2T8 z{DbZ--cAg9dSWzicC**&A{+a5p;LM(Hm7&CZ25igcmBKl2RRS6*BMOB_9mvJXt-}bM%9QQjub`z>S_3+ibGm5LN50V}fRk?;<3Ke`1YglG~?7|)j5*FE1 zIy^j_miy|U{aL*=^+=Vq!7Fcfy&akR(jj6}?APJTokB%4B+@EL^_JYce565bAlSc! z@}-05^y`U(+-6gg1C#&agqU+EltjrF2{cB>IzG!RJ6RgkQ81Rr{+%u5)2p99e{SX2 zLU=M;AM3Z{mh%XI6q7Xi@sn=myx7(4?vbHJQd^@Kh(ekanS7|>W zrUp0P)X?1^Oz%zT#Xq5J)@M4se^k0MF?YTea{_J|an(-W>PuO-f<&-nR1 zGmm2w{wjJ-Uw`1LF=4(9yD2$W_RL`NOx@Jdj3~p|d-oH!oO`5`s8=2NQSl(b#xU~h zVS%_+gAxVN?}s}bi;RkL?9Tr+`ZdN@wT;6{a=yDaGe|0bB)>G@H}3h+) zzMiW8vaql)9n)^Hqeq`aM8J>Ayp|oh)AIj?EXcnvc$4zR-r<7cN8j%JYj(D_d@=`# z`K`9Zs^1pcp{cIisymFwb#!!2oY+E7uXpxrTT@fvjS<~AuOCJm1@YY5*Yqaq+LN5s z6%YN}BgB9F+_`T(J#vzgkp^89q-}TgX)M@M&eckVS=f{4+h zB0Y+X%)9*gL0el}sn=qQ%{x((PnJjc`T5;fW<(C{Kg*qYV%LdFTH@70AG=R?%LvIB zWEj5xYS*pFbGPp6SJ&&;IV5e{c>GrP2nuShmNO7pbdpLhN_|$`8lTIvix|(&&Mu8a znZ03P46mGzRnef_s3 zqRA+gg75Xqmm^ERhkSi~>z)XA=h;O)I=v|W!En5py1`pm(w>>%reJr-YUGYzLNRu*|YNUjpey9em*{r zuIc8A3f?_?l-1QgVv~i1L#>A>D-6uJc@O5~<`y{itq*<-@mcvRY+M*#9p-6aLQnF; zo2$KLK9kea4%XJzR#t!ij2!0()tEMjixM?k$7&TfR$IHe-oJnUta>j&;^3_}bX;ZP zQ7C!)R+yi^sdGh%P2_1LaO*H%G@g{zd@mu~?(6|1lPW1Qh->YWC zeztpwi37RTjaeepN(^np9{n?ChNh<#d{;fh#dCXljB%qBwqrv4{M%Sq`hWeBzxijl zzu!D(hzMqBBW`sxH?y*^nA_OYIC7RxU3YWyogZ)hJ@nbxl%8a*M@7AJlHY-rmbSUM z89#UH)-424%b=gzRaL8Q-=^NRYZv)^X#d`(`PTaCl$3pE)3=J&Fsa?&|82;P~3s z#x|)!^aS1{%r2LXjf{=CkJdl4>&UokMUk-D=VgiS-?3wdA=bs$MksvzSv`piH9Gnv zJY4M32kM+{57SS@D^ylinwpya{}+{=F*2GQ8M$)p+BHkdy-tp+TfcU79XopTC@*gX zjOyjfz$s~hl39`zmUYyZamyC$gl_vxMk4A^@P1#f9PlNy*ZKYsu&^k?jQFuHk0 zalzQc#Ao0G4Ldvgl(%o-ixtj zMyIBxY98=w<>chpNPK^K__FSExmriYtpqh?yMGPx{MhiYT5@x2T->&;TlI)eN$;iE zr6mvAfJ>Jy-E!SSFeT822ju;>yK?2@_a#?X*X`S?Hg?2@M?}c_t(UxdHTL+x85LF4 z+6Pp;=Cg}TnXYbbv@|rKyN~5$WQg$dQ*o9X4@^z%-nK2*Z^IX7Yrff*IAv1oc5P{P z;NioE?lZl(|Ea!m#jd-zgC0KA_x6^H)BRk;qSScl)Kksp@?QCus_wM6x2py*os5-! zI3@jI3f6)Up>Tb;=Fz_6H4T&%GkR3?2FAuEu9NL6i! zjvroii$FvR539z!SN!Zz*PFpmW1xIq=)vg@@tgzE&XbJ z#oU#s7E)5;ZDW&)n;ou=sQU1s<^fg3tE?T>J{=h*bF;He?d=NKv{e0M?DJ1}-|Y5O zik`<$pB7^2-@ji}R4m^mZWYc&w?6ix*xkXw0cOX_%F4I0c8%2d;UbGW8#}u!{$ylC z%*v1>w$qR@v)TiFmq*x*jJh|oi{PWAumsxq_Lb1STU2ybSFXG-D!LaE62ZlwP}SA7 z-2T=GIqHU+o5Qth`(&92GbP6U-F8DAH8rUvC2)-&9UTU`x(}xGW+hWnQfg~!-0Sa~nE3KE5-pg;G+s3_*mo2!agh|Pa*cRj%^ zIS+k8ME$qC{5AGi#FTVJ;^W6`2+%@8Zz?L3VtZ&F!i_kQ7vbBxxwxW)jh;Mta)*+) zsW)?T*Z1#;y=psbmBJ(>C4cq2Z*6Nkd*Q+{aq(@OlI*CMFy=_az^98D3uV<|K7X zYl?oFHd0vRlPA+%xnn>N*ixfXPs#Q*bw)cAn?J*~pC52VJb5C3Ez8c@F~5<0)K@c% zJ2ccxA8EkW#wII2zoD@av5CmmquS5L@bb+Yq5b={%*@IY68cB#qRPt3{`~o48ARE# zfB$~tLZ`~FU%!6&a*&VDA3jbEJJ=^94fr@sEw99z--MgE*oJnDO1qEe~9Xt8# z2swqGK7E>OqzVr+D73Y;Ii(!Z)6;wM?ASv!_+^=9lRtg>gz!2%HWm~VG&wQx5N1=S z<|3T>?p@p0uh-hsFU*a8^+Kj%AiVGT`>S|(6ux=$CNlCIVib~B;fY~er5iWqU}Q)Z zKo|C_Dg_8Rjz5Yhu*_}Ci_U!rd?JX^QE`s^*-!B*R`O3mm z@voSus32Frei<5?z~JEEgaqE>$Lsq1HV|F|v9x|_eLLlkA3sKknB2R2S6fTV#?cY^ z@gbnk!-r|OU*NUWR51yOce%MS3O+@prSj~P6akgU_a!KsLPNh+2GHf@<-K39BsF@j zuPy_CSbvR1Y6)`P!y;{*L2*K0?o%jd5HpYO@}CiZn$0OEo>^SJEbi1_5hh}1XLtPM z$xY5Hq=IwK{ey{5s7L0_U(W}FJHc# zn3y>JwNLz#h*XdIXPoGGR-2m=o08WFQ_swFwD=-|NYhdIr3h$<&hiLue-Roc>DO2UuTTd zW!|)Dm$Vaas0N&|zrBet{_Sfrj`M)Sk85PwRv?w&!@9B3m2V{u;YP3(pTl@Irs8{x z-5aZ`tG|5tSsNkP!^KVt*}82T5@5CnwfBR$8_MalJvYW0e&k%)?EXps?w*=j)82k` z|9+cK_jk?sJq04aIX4P>1tK2|6z%sz=z3$2DIy}m#LNt1+9S(UVeM4LlZC7HsNP}G zCj+=APtF?}J`D*ue)HzBK$QccqV6|t#7H|Q0d@8+(I_XTrtaLe>%Dzf4#Kdyx;i|( zUtBttuKM$5QR6}}d3h6U?YooDdb_hPf3AxX{T#|UHa1p(%aM2B6kDZs6Uw*=Z1+@u zKWVXs)aP`X3e|jP_GNiFxhID%sHmkZe*gY`3nOC?6W717ct56HM{YM|6ZCeT~sni#~NX zBqY=L{U7A|l3lnMA3`ls*VtH)mnVGa&|X=lV3~Zo3ZV3%@$vn_!Y9?D8!r~QTs?c1 z?X$HZ#ZoA2JkPGT9=DyaiE>u-8aOxM${VT+jk9p4(lq_KXvMq zhK7KwY)pVM#nQG&9aneva_{9orxe$ak$StjwB7hv^zA!H0J@xUlEvUu_Iem3N3yvgK1-EWOD?Bt0P-U~msw+E6AP5D^9Mc~iNTWIwqeiV5xU%z~r z|8uDYb{J+VUCd63Pka5^wA^ zI(_NXUu1O?6r~K7VWcz6G`~Cs!hPO^-*;Bg<=X z8va>aSXlk>qfcqzt~lMW&`^1=zf*|q2&zAS{d$|2SeZA^o7LIPJU%tGcZ72ds{?IP z!oyCALLHl7R^A{Z>N45>Hq-dRg$wUpCQh@nM>&?S9TgP?!n^k4Zx&!vyu{TOGru)N zw78cj1IuGwIp!#b|N>8`HaYI~Lzibe;j}?=@_j7MN`@i3xdCFZ(_X`Mos;@tB z{J4<3ybnr`sr8HP<$fChbj%8#^ZSn+xeh0ey)(!tvhXiynVbQg$zR^7RCXHU6bAMOJg9{8{0p|W1K0oA*A z?_R%ogXDOO^S;i@>+M&Mpx`z$i_cZ2pnez~{XLXZ3ZM!04YWi3CBFT8mA^eMo?`-XdGkH&q>+}Up9iS{UxmPY`0!yQ7Xgc46mbeTTLfSWi`d0T zO6BUm8@>46^6MiEtomicz^Cxv{RD3g@-f28-G#YMS;J+zsKKr62efD zdPqa;v8tE>5|5JyZAELVI5Ml8Tx{`Gg|?oalItU1_%$EwCr4)nLh-n+?)3OLpyX&* za6?0b-^Q9$;~lCNXG5wNFJGec@LgZF9#ohpar<@g`0?Wa-u};-D>krTdFzJV4ZZ2Z z^@t;=i4meaj`D?v%;i=9@`K|-;irM5!cLWKkBenxR{(;B{zXtwYuVdRRo~s#RVt*Q z;0rU^Ca5dtz2u~>PGhuqI;~ARLO|YQmh?keS!v;?PdGNnXJc(KLg3_}!{!&ix(hTl zG(ck|OU0{FuuzJcl>ymbnS6qh8A-F_+qbu-rT_jdOzxI=32UjXsWD#^VbLpHW0?mM zybvdZ$`Q%7ukMwklvJik@ljMkN008fCZbf07o;R4vb@yBfJOj`>OKD0mkJ2C4(l z%??y~ty5D`QQ2<}VG=1}Umprnq@|-nNnitzj;edNA{qU_=-@TX+;ZRsQ{TRY_g--q zulV#y;NZb~3ypAAn2o2Wr@g%L$&+vM^PjvmQgU)0>*gi9#QN$@E2~k|bXekqi;Pi0 zx#xoIx7X)%b#-C2|K=JMv#F_-67`~EVva32UaAS-gMD_9Swn5X&CShu)IR*tBa6vR z*RNl<=@yhJ{Ogm7%1uK<1H~~s`{(53Ojqu!?#$-=ckfp5WpiVV1=-on-@bjVuYb{O z)ANOjcYN=@eTvyXU_nc!g1`CIC>N%8W~rgJBJG3X?AvY zWaRcTT87zuTWD34l@pVa{NMJ4>&7em8W?a!npm9FLg|hGJ+Zt^H%?{&7!TW<%g#={ zEt_!=mRY}V$4Lp9ZE6s`8f7g%)CbpTJE*z@crFY zxt*Jrs7EOX@2`LzA3hAiHPx8aUbU8%l)&;3S!RFPG*>2bkZuyhXON3iV`G~0eRc2% zgEMD@4jkA8g9if{6T?~D7e}jH9m0WBgoaJUVK7;ajx+m?KXLjy5&Yca_h4XHSR*_R zJQ?bJAT(4enh$nQjEsD1YEo)Ss9>hgq`*1DJbeDm@A2d%_T>Anc}hy=#ryud8ysYC z_Us|UtbQyIhjnF54Rm!!ht$4(FEcW9&CE6`Zc&wVg%iE}CA6rgEX0X~sMND4d2Mub z?nOnhyT==!IU|m}LWWik+2J#vNB&iJ(l8lYT9*855|f+VuV3W zZS4meVi;4>W)PMM(0vH9R1bB7ic*hnV1vx%Gz9UEednyGP(r> zJAeK>+F3Ir;}wE9GM2rPl9F)6%b)Ha78gg^vfp}}KPWP6sfUM$@%!r|K|H?W%b|b& z{%wp`?8~|o@aD}V=){PKh+{jsmKboVl*y%~K9&A7M)~%~oCm7`C1LJwGBY%0UZb^i0KV>d>nj~xSu$X;RT=fN!{&=O98 z#l^)aneZu;0%c|EUt^?HRaH^)1#Q`tgF}YnjE#=^7G3KG;SNVa;sZ1Sn^ap@2M&9> zx3pO5)dIo>D6JbeZXlG{+uNf8`BGQswJ<}wcTaxMmOvF}CnrOMu=ESL^jtJb<0gep zmsbA%!d-5>t2MZG?N_9bfrW*|bGe&;M!#~NJb4mz28a{@L?A)XXGmVMvc-;N%T`|= z9~c}Q>|q3O^gn;HG;RrhgP8n@kN1{6=&fisV&R_F&v8VWa2m zy_%pLKwfz;QbZV+tQ)$MeZDKgT^EWTMnvRZ>&}05D&CMH)7x`9nz)$N2p zL>FHE^hv|}!EVxJA%l#^;o&R^wB>j2-YsAMXNLym6Hry6W+DQgLL6HN>)^G%l>Gej zy1Gm>1nIKispluOwU7C(N_Q!DnqRs^4s19KIy$;AZh4s_M;=Y7&8k;7P#rmPWQXv@ z$IqYZ_OK8tha@Fkf3~F|MBce`$I^O0&&A~(m>D30{}=?}F6aKs_>hwr8|gB96QBw% z%i2SUtOC-*ti)qrdb-Twqman3Wibeo}e4-_@&W@$o!e%7q#_I$yqhlN1$g zgn>+TW~0VHY6LMgI7VT(nbO<$bbb{hi<>SdqjzIBf0h z??-L&xZk-*B_Sb!D5ghCc=8DdSOL1@2&w3p_R4NPmIQL-+V{Qu{6mwIcGlLKZo3~U znKLG|PQX#FIXHZhQYDJeSiI!m;5%Bs3yerkZus3wb2W8!^b01($Cs9uL%HRnr_@v) zROBUq<~%4QluQsVV9DxgYbW5>Y`-a$mdm3y{f|tqe zDk=)Bijp!XV0QEQFpl;*YQ!Y$F~|?gpsl(v#@$sD!ej^BN>xQ= z*5_*&YBqopD{E^ZA)&C{$6D&@-Xtd8n$r7lr(&aVz*lO$%ms!?o2#0B;-a?44Ji2*kS&6o*{|NO}(`&f+E zjhmFmOUiU|cL!3tvJf2f9F*4m;NU`!-xeMoO8`H7{QP&PZ||SSaVtFl6OBAZO-C2O zr8aw?OZG3qx4iFaA*}i3%PVMvcHO1ki38UPkBM1<*TU98Iy{O~*F&YG=(p}o_EelY zIgCDfRHM5$oNPsQph1z?6`lG?X=w+#E>Lqu>bz{WL46^?aSas^PzwSFc22AlnKqBd zcd4?abmW}rF@nVH6>!-S$+LI<*PX; zEdDQY1bajLLA^DN`#0sSZ^>k)LnHZQyuy##+O&iO(v+Sh*k|IFT9#&s5@4(6#>Flq zXCw~Te=UQq-zEb-v<_M<}qvKHNFSU&57zN!`(Oavdn=Uw1p z$nJ`6Q%}Rf!oaowfYj)!@g|K@bh6Mk-{$_U67VhgQ&rXFuhEi7R428y_sFvDA!9cV zeEdjy)Fw3Gu1Z7)1O!Axw1TIrVfuIln`Rt{qqN`-`wxt&yHMS=yL9? zSmS8I>4NV-!kt#wL1G3a3m^Ub=~I%rnAo#=G^`zd6uF|P0p&0{I%+BSlXW_i=8`y} zM(tT$6_k4RV5ml2O-(|x4O$8@GOj(>vGxx2Owr;XfE>75Koq;LERhYsDJ;;jfno&f>*XJw`wER?8WmO1!7Fv^p) zr3escjoy6rB7~(h?2P1%QW4 zW3%(;^FMW6IAbDj^FzkUDQjYFMd0w^`TqXC_6)Anl)qI!zab0XV(Xd@x-; zng=WY{t4{chp@oNd2|ZE7p;$Y(;fd7`;r+MzwyZJmk=v_o0xq@wd$nFnwRT>2^DP* zr?i-vnc4KT$z`9l)DZUfoMGWN30s0*2w&L%ZVm+>P;PnrFa4ohuf@OdzMf*1HS4}> zkE=f_mYSSX88qKmcc)|gzVySLB8d6M;hlTth;*7;St?;34 zKmbisHimsi`1r%-UW^0Jzur@ z*KMpLQ~1i2j@NHA+rIfs^d8R10T?cAtQc9mjfkvr=Ntce%gf{Atjxv+zei8{(GDf^ z^j@`>c>HivOG}d8enyvn8KU!V+su#D0UpQSsw?m4H{Vac4Y#bKU03t&Od8Lc*ys=0 z$&HoSvTsr46@3~nzjz-B>L=j1W%O9S9tQ_U{JKDmJ=E{F#J^A^fKa4(EULOhO zSo{x44Uo8hE4@SUDsOMIH15=Btp2euC8awwG%_yEx;#tpg!Pg5wr|ca)hTd2wvQUy z=b;4baN(rzbU9bSp^@RK;SuZM5k}ADoUo^zpO34vF$AhCFaD6dCj3@8fObKFDhjvD zL#Pl~kaauSC+)&U+}>H!q8f?DJTmj=yu5f0A3mvFd&=?B>HWnV9C8o1M|6lOI>LJk z`4i`TR<6DE|1k0Wdzi1n*^B>jX-F!GtLM+C@`#x?gXG^-v33(Ttdsy7769F}cW*cs zwVHon;nGJ=DFJ{0H5H53{wf6>W{x(qS8Y75ZzwE8P^rfbj7kkXJz)iP698kgAr&>- z8_iYF65yP5-FC3-3FA>r`TG3V+ZXc@uYX4EupU1#IuIS5Jkx7ZbTcA^XJmNvq8>5? z)ISUaed?n8pguwRY=7%AVW44Bje0qNSi>%ovcQjT**n8Fv zic5`A06m245QbD_teuL&18^B&)~P8$nSz1>eV=C6kuQ-IRYCE41(?q0jf;xtUOFPQ z^S(i7=#a2*If^~BN{x9Q#;i~vM)^+yHNlEr#Go|}02>n>4OJ29GQ@P&@iA*~GEn6- z1^1+ruckbxpwQdcI0N>{u$J1FS5VN}!QpvSlmTF;m6gW0VNVZpMWr5#8oqH=+t_%u z>UkCQfE6S8HEF!&eOyH)>CKyvas5I$GIFo<=l)$$;BkfuP$?;7LA$WJT5Z!Nk@7>E zn!NlAcHCYgSa``F)17Vi$0h!hIMj_zhGftSY&c#(M@d0`<+qh^uw7jOnZnlBuPiz^ zt9U+a>BWn^Fm_)7bFU9TuD4*Ty!)x`h3V;3Z*3%M%F(kpJ3A}CCg0U6v;3ek3B3dQ z7GGEcsWIELT2MgX5f^oWCbhj{LQW14Q-skes(ulCFZL+`A7`NGX0E8Cs3?F@Q$Hk? z3f|A-?Ki?={k1CE{hdVUm$rFU2avXAagc5(R|J7F%)i$C7n==`_9%{qB=7%jM;nUU z&d$`Nqz{p=D4b|;z0$T9@&m75A4DUr#sAuskJ~0Fs8k@jNqtl7Iwvr*jdGTjx!u1|k1A^iwK1E^ zpW%h@aMXLH(+bPS)J+vUYr^AAOd!$t6k|`w3GAFBN_G=Hr=i`s(wZU^9jz0;vaoT6 ze`S8tR7c}8jiE~GbA-n!FB!@_u=U0>r(0|t6}fhwFg{4+1c!ud z)?lo-$TDw=Hol@kM*4=a!Oyw8Jk{1#$(ELOw(WVoe>Mr9QvX=48{4m~K|Fti>i_;4 zQBRey2&^EbdttftT*BigJCIX4Qq$sgAK5$lqqyA8PO-?`Ki1gy(k0O0nU5%;rj2Dp zC24i!8HinDZK|Y0;`~)pHymSv+S=|%MlS3YL-^FvOAF-U;xXY>e)-_^+4_ctPcwYH zmL6F<+PXPNUy1%FUs_K8X!MzVm6d#1EnXqr@A*Pt0dC2x^sv2_bv(#9qb~GrN;sU#(G4aFGHZk0YRzb`p!BI z84{5)%BH@NuYYQ?OB~+$jFAzH+aRRIy{sh1PXqfd_3wbfr{%$e)xUqkdRcZP>xqp2 ziq61!K~-l(ihukWn6XyB2=)N$R*Lz<$D5=|3hwze8&zH!p_ow-q> zD6q88>k>)7SKcRJ>p*@fY%DSf5(NO^&$n#4&_7;iaNbS(d^`QVeJ3~FbrZ-f$rXQ< z?d%`3`uz24=oBxpcc?WF(vTEjdmKe!p|q>(0ucT@Ga(}5G28z=&u)A9>@yppxzc^i z9|RJI$C49t7pmLYB=+u`aPc%X=tuwyxJ&l$3L~dg~QV2$-|d-_m!P&ZZ}O&;@jwZkojkDFy!1>QyaU+x9RVWN`{7o2?`vf zF9%}=UaqdHDggK8eM_`~rqpxci}Pr`;{WxfxqJ7baSl*Q(lGmWJn#Vh(bSh z^>4Q$Dn~)wTzPnSxD6C^F2N4#(F*Q=)4fgNy6NI#8sOV?4eMWo?arNBX>VKgS(kIC z^sk|3U_cw*|KUT{E@fhgW_g7|6JPZ9MfJ3*`+);oT?JvSTD~K6P%2~+7qopIo;ojp4KY~3)4nu3+*5OH}lHp0&|}J*%LZCgSKZ`E^eW%f`S+w zTcUY&yqAV2?Mym!YHmO#juKb6Mv)+39cfD+V= z)xvVuRsLMm)6@)V6D%J(v!5+O!3Pz~g$-Vf#5V7Y6WZd>P#g8=Z2IsEHg43!#@)Yd zZzc^0-Lu})-eW7}^JaMkA!pL8a&qEL&xi=5q%DKpsqpJ5^C<>j1dK(#E!n*t1T@5t z8@Wyj{~RyWGwN1)#gm?z1x5Rcah%rGt0;MgFP>mAMI{QRy|Q~R5xD=qrZ_b|E#inCw6!YNSJGRJw6QRDYuDBvKeubaJ z{0mY7)Vo`^mj1M+<4v#e)2zQS+3pt|{YdfVCZ+o-hr=H(+f#V-wp`D9e^WevDM1!0 zG!W}%@2~G8+Xx#{V^JCfmf;tx@XOG2Jta3+FwCp9>sUdm&D&yG4w6y<$O1@ln#8vD zA17ywHI3LIe;P&mUi&j=AGiN_71DCi=-Rz|J2EbAXWQ*PSiMPPAN9alJw2zzX%hnj zgPZ^~s{VBN8y?;G_XZ82GdemY2Zihu*;Di1)aAam)zQA|Xl`C{cN>|vb{@Q4Tr3+< z+3#k~vXzdRdvzd)N6d_RiNQCGf{oz{;Q!TQ$6U{(jE|1$dd*pl#YibT_R*I56wiH) zar{!}dgBIX$_eedT4xA+R1!fde){|wOyX15q}O-dN1hYfu>RJJ5eHYjAOQXQroriY^SaE_b;%Q-2^UOB{RU*a(ut!@lE> zgrLJfyoMMF5|2QEJ$xA$VnCbRX6L4+ZQkz$dya#A!}>NeGee@q|IcjG(ghIVadBK6 z8kF<5XWIydz@)UaU7VaorltyZ?deD)&qFt_QZe#VbgEq6ATrQND)wH^fs+31+iUvz ztfE1*zD^K5fTA68A(WUmeUl@pqd-9S3JtB?ygbl>!@tLs_+;XI(N~kWPd$EYPdddR z5G<9ImC&8AE>|>P4SP4vy`cBqc8Xuc~^AphRqca0&zE_B9wgvLU0*4yp6YVfR#pN5^d>LQ70?=f%Cc9*6r$F9ckX-wwFXMh zs{X0Inb|3aoxDuAV<)Aj9sa!$k&$aSHk4`PKpYbTA!2_S8o1~wqjZD`(C_4~r&>x1 z;)&Gxf%6CP1#$X^sw%+~PnF#G`UP-P{r6`xX?(^Z;Fk6Jdni-`Ds#E63nV*$_D2yB z&;!>pQSy4+9`;g}mqK?1>K8Po$-Q~}38q*ez(jx9 zx*-q7W1Vyc{@_a@N2A zDnn^^&~}iMB@u*^hOu#Mr8&#L7#YO|Dnp3pAf;?T)?WMPilPwGw{BRy#)I9@A|uJ& zQ@0(2^1cHHpncZqdC1@!OP7_!wD}tLcCh;HQcnr5jLUe_z3oDJ1mPJle(lnF99;?Y z{RK_~AJE$J1`%v&ee1b}Lq z^*?{+m1UZ*!Q=)6&ta5A({VK(#;stuj-#GNFA*{$>=-(JIGv^F(NQaT`O?l#LmwZX zoLqwb=zoJs#$a>&pNpb~*|B|lWU>1Ul(!fgc%PH=Xi9ok5;lth1fq&@c`!ICi&WGp zw0is50F=b78|u+6MX&o~2)pIyTaA-bQ<#rZs?s7dcs+%+!flB_705oHqaV7y4!Li! z^AIiiVB#t&DkyQtCnuywr4XQzZ1Vp46>zo4M7*+_SLxz(7w6_w@ih*6{@JH0wZpwE zu3lY+iVaL>e-?|`dyn5q_wF&2tuj;yKc(}2n33Ut@5K2A1XSS$AlYYQAR8*DZAC{- zU7cB}=Q(tn+uQSBytq9joiq)N61t{v?+SIp%|GwlzFj?2zA-&GXmRaY+~db;XeEsf z4GoQq1d94nIkAqSs-B+a2ETS=v>q;lCWl%!hX#7PWGmmEWJmZdgikPPYHC7Ag^Z5= z!GM#@%&)H8&#kT75CaM~+bP@cl=pOj3gAB-Hu_b9UCwG+a4j93|5nP??%$dGUkF)2 zBbJU1B`0mpP%r}zPduo}gb7ZfBat@iBm_gya{woxbB0!}%z&CWC;41azv5*7N;c;!wuaKH!&w z;(v5}JowH-qFR};p9u|N6y$IsypD0JC@d}kdTxwkAgTkHlEkVmvUKRBwZ+N0GjT{@ zXi7v*Zg-#5^h?OXpwvjjT_%b#_vg%@aRyigJ{r?!me#E1h$x60-6c205iW1NGhL8J6jkUcD~S#K%-~Ip|_+9jay8Q!ODWE>#000L6HP~D7X*V zn}7CpU08}k`=g*h>Szmv2bCmA6Fo4<%E`t@)8qv~Z^A0`w!lEKPRlmzARKAIRG$HIYSk4Z~IwDgE8Xg7Y8EEGT#nB}%5mhm3M zg0GX4#sAZfY{d;-L{~~nD;QA}#-PU9K=Yes)26%VS^mHg`WPuu zA)&g~R=CY=@;2gJAr8^NU9$$x-1!p+soUp|AHP1AmsM2U0NQ}6#Zr(@2RiD9{F*4( z$W&G{T*mjs|HXwENRpuW4Ce~k4;3of0}!fiu46tl6htg9uVOU(VZn$j4<9xo z*1kV{+l zK4W8N_a-e34g2Pad8kPs5hm-bfq1~zB9OuspGiCOvD4l$EkfGHa>|kN>0UiMKFbo)IUH|m3MKwK>??_VX1Z=piwH117-0R>W;_vPEm;V>MN=sffsK36R?=y z{m63-9yf0W1_nZpd3Wj%&zrYzp`knVpI{i`6p~Z4wDTbPnPUxcbuj}FNWwpUIfUIy zqMbc!Y`i;EqZ_>UVxQkBw0jFj8v3YaXEDi=2FZEOlLDEYka zO=_wQ3@t^_j!Wgb~`0$jOQ0Cija zcyqEoLC{CW$16g4>E_6aX#jES2MAO!#y78DcXGL41{M+i-#^JtX5NV6;uX*ouTIAN zMTz3;^YsC*i_DL{z752{;k71$?M8T$EjcXg$urtUs=eHb_vhHrn3?{#Ww z8(76PjyB>^v&}mcQ0P%tba$`3iuWD-Jw7lnfNtRU&Miu*5JMe|Jh-Q_x32(I8CuM! z2iO?Yp%upyD0vjD97t^N)|-igy4>ng->SW}{EZ*Q+0 zC##Z#wvNt`oe+yuwwnM?0qcWy8c6kvtOn9u$V0z6654diQ&EC z0C+p{o%b5Y6!IfdC_p+jH8rfm_kV>a}d~4@#$lps>l0xv}>k-B5TFCeb;6*91-sjw_pH2FhEiD8M_80a(h#g6Y9u` ziHV-N&6r-tLl^3}hX4qH<);50nc6fb3a@H^y#`-|c`B^(3q;(6#t8cXmd4W3vaqla zQ?v84*$C~B2jKrZUB4bJcQXguv+ag5G5Gdk-cO)xIZhEVF(-6Qf0*9MGVnjVr4~v3&5;_4-45-5*d*w>ry>mfZ4P+OHiSWqC5nT&nDfZt#PlOoEO1UX= zRNg`}6SBpqD8T~506B4!k7*&~?%g;U*Mp#|B#DGY0F)p94Z7M}Sv`q~ z!92&mbwwj|U{1yy{rXINEIEMNSRdjI_i1B4zYT6|4UL^LWfBCds;c~l4?k?u$9W>x z7@Rw&KE9t&NiQn;-r9NwwLKnOfrmsKkHW}~137NdMfHFG{i86 zg_yjFi(JFOEIP}4WD}DK`j<%dkXphaOa2?~ufklpGBhFVJo27x)zy|TMOd}*E_p~x zKYg-7h_T;TtowM7nmR2bb1|^j zK&&pVuD~spzpLK6BNF&YuS^rkLkhTwZp_n>YE8{Zybgn3z1%#bd% zWSS}BG(>%LQ=l;KcZYl#S+KsYZfGO{BMg}AY^p>Aj9(%7qKPEH`CK2 zp;#=6^CM(4ZrXI!(J}t$weMh+ITPoK`ggB{>RULGEJnn3=T*;UJ^XV1EV){>E&AvPs5Vs(9s z<ZRee zG}K(j&Bo}~czAer?vysKAaYb$DZM_d;#M+HL?GPn?&)r8dk2h%)CVp)=NmVv z))E_!hvo#|9Vnn?5l-jdK#G@v;c5@f8?_Ls(?Sr5Y$ROLb4rv9g>EnrdIOf>6 zIFukj4^{2kFmHsw5!_LRf8c=(=!al*XdXIU;6ufmf5_fzDLBa5M)j;<&F$a6rE=%Qllh zJ1cwn@(>gz`3`t67C=iwT^;#p2wGZCk6z;_4gfNQ{x;W(on(ICzyas~rkn>~nJu$( zrBbXqw7k$pxoW*A^1orA>piKM*R;Z`u|GCfY}CmAX*9rr#i5I2YRWBoY7*QPH&}jgCv&T^~nSz20Ph1KK2?5;&`D!-$5G6;q z_&;zRhKNt`z<~*9cT~KGCZKVrCI)hR+-`Qg83|13>rmV@iUG;C8_+v6!UZ)!;;F==V(6Y&abz>IvJ z3HWFb=Ging{pN4b2K+{vMgFa+(LxO62+fBE+fC7uy21<0J?PT(e67xre{8lSb6oNi z!x?o^8=0lDVPVU<-NS0_FMVD5bXQ)r{yU)zEn79$iId2 zgZ5`YMTPZ!(^GfhTV;1P+!i_KH`pDXvWA zf)NGwoy^h}k(4}_b{6W@b8o(Z3Dio|LuBXR+dFHIs`_G8J<|( zuCvn&lJH!sulbIB8H=*e+JKY(M8=r|J{UBBytSjq)$Y`(MvU5V9lNe3`kSN}84)24 z77leLa2{gD-3R=;Tu9Y;R)>^rn+<4f&k}32VPayspNU(~do+GIf)_)Bf;6w}vXS7B z@8$7GNKM6*aSpWPGbM8kCQ zX7PePY9kWODCQ7L`~w2EGpHj>T!;%W#glv-|G%of1T4lie)}2GaA-=>LMaEKU7<|T z$Nw%iN6pf{0DMEy1U}m#%Czzf=$D(;bDnbqf|agQBYQ< zs0+kB*mVG*RAVirItJhvK&h@5VqBnKDDlzk-(VBzngI)?qM{;1NM)qXBDcM{v5^C` z6gl2}8N)^`#m%WGxj_!#hX@EjIgU;M!Pta#$2>KV{5GSQ2zSwD^PQZX0T-(u`KsS0 zd)4~b2!+f~0=lMT=4a35cvSC&ISVF4oyS@OhQPiJPzIwZ2)9r_aydMlMXQx=BgROd zk?nUw5)bjZtGbPzo*s^8tlG_+PXVGxPd|jn&EC!qr8gD|76%Xhz^RMl1^@;m{isEW zH(iCJSAS~7WH%ro;OkrOsH1iq(8_5+A*8Q+cW5s0#q%LI#Z0Wcp>bAuila|AdXKEr z(Z4mg#8GV3>eUxzvt(JTdb!#UU2+FFghmCGdynpKYXHf95+-kxG^`5-`1 z?2j^nN)*%-Y}!33bcPXrD>*qCN$~9r-UuoT*y4X3dp>Uzu1Cc|jygNhYykFD=>hn; zbLVDXS@is-5RG2p)X2lk50a}Po5qF&W7#T+~79z62ZlerxK@+m{va-20#YQqW_UwI8j{ z{V5x{dlIqEMmW+u9>4h7X1=J1jLdNe2(e<%GyaqkTC|p;c|a&&J}vws*evD1`=nA( z+OXy5?`d1>jQRjwn{00HEWD5(0$M+ z&@J<|{MDI-?Bo5VE8!Nffz21uz8l(xi3r>m>!~?tDi5Z9Aceu6euRBlb#4+!BUkKb zk&CBFLZqBsUw>IpvxR;=8DTTK=fkfC#~^Gvd3+^t!#idZ12sFg-suN%jCj;ym00@o z{tHO__;#X)ksM;h^kJY;vIsYY)@-jDL;MW(8#s^zje0T#&*HFm87JuU+9I0nw5`sNsbNuYN zb3JR5FZ1o>;J@zb#1brHq}ci^H)mJV-Jx* z6`pv|)DNDZ6wJlL_U4!>+Fh_?FQJ*nqQEK;i=8$4%U|mG==ZdUG1fs2XB60sLe6*r zD&7KPN-A(Y@{{>MjP0^KF7xH83h1sGh~dDR2X9_;^qWzQF0&UoEW%NDsHqYKOp3yE zCsE9ODrI>ip?z4iSqb%VX0016cGV;qsQaXnoQPn@V;>EfZ2hN zw3aDbT=Kj@BC0>1U}-5tVt0!?B$tC{?eJ4kP@oJ&hWkDTEK^oS-gYa|Ia6GjBP^+g zM^247KJfF}6jf8q;neGhnn=mWeF3isW1C0LC5&o4xx&-xBOTF)aIh#ev+8@plQ9YF zZU;WnQGb+ax9L^mQ=eF-(0yKCT3tg~*-%uJ>~p_9*+hf4HW@AZI>U9f+Nm}oLkB4R z*V{?Aa6u+31I0T-{A$Xg)a~(kk-EXSd8ms-MvuouHS>e7?V1OhW}oObhK>dK8k(B%(1HvN2~RyY zPxJEr-4Bq~A&(=asAVE1BUhXbMPs99Izaq4I67?Pj6re01qR__(3MhOO|Q`r)vNom zAD=(3o3pQG%(%~2xuT*%h~bdy7TS8HD5;t@0~{wbFRNisgYzgMlhtSdm!@IwP6zKJ z$Hv11lkQA}-|lmF&q;GF_YLb6+p3~sDJEw3X$I@9eQsdx z#I8;2hx4POP2{8D?lxdtqu+bLE8X1%G3v{X)1{f>GKF6O*d9Ybf~F>%E|hlZ{HD{} z+gz|g%Y5?QxUqX+qt0*)%t&A}7_~BVsnZ5Yy5uKdW^Qk9)(sdhc*63jyHeiQJuub2 zPTdA&-4JmOESZCSP_ho2H@NqWaL}yQbCSQgc3*^p<=vOTer4sc?eOoh}^9+e4Zz}s<-w=rt2-$~C56QrfGejANt zK$oPpYwY4!8}iKA72#1)4V{ZM|7psf2v{gC_`}kCW^?{T*papor}+wNBwP*y&9Q&V zhHHR6V$Ln@CqFih4mQ+G{;rG^S`>2MajI02sDHd%RMeoMZh-XQ)ElF;m3eE)Tz$G^ z6V68tPEWu@=(i5=NE=(w$Rzqybw0rb#h`l~z~$>QztE^J77Fc8Eg$3K%=}NZsPt@426!29LL4p9%eX z-9zUTLoE*eevq9YOYrTk|C=pt9fQyWcCrm=R(<AuC zE5C|T<71dU`%%S*xPtoAu7{e1a_CX$z2LF?qo+rSIhW`xlyMSCv*dv=3T%(G(x=n7 z?t3<}rVcheY3=I^LUlSs5f|RgAeEcxe8blE^T=4>)Pr~b^~0c1aTBJAi< zd^z&7Qp*Rrd?7ZLqkV9X#EMRws|V3dG7W5vpq zSlYzm^%v8LDvz}wauArk@;Tbv|g=M@Gh)B#^zR61AdkvNX^f(I0%dRyD?Ai`600V{hABY8do0GP5NLq!yR z9#2+hK(Bp+pbsWKZ{-MAznb=Ro3Zh+)?f*o^WZeLxYLMzhNih;B-CvK)rMk(bR`I| zvk*dL+irBLki^YceigI2{?yQ@sjO60QF(y$E+La}fAEQY2%K0E*yEX~28348>V~Lz zWOB3{!vcN^*Ret~mY8k?I);B{568<{q;MXt%Y+6E;SmyYYqkif66iRU+KOoo${CQh z*ZZXNd?gY`XR|E!WjEqcg3p`*h^~jes1nAd`w_=LhlCtpfy_)Y@XrxeB#^9;(a{so zv_XkMweL$6&7!|}@gXE~I1UH1%n3!dla*`r_ zbe$JqvmmNLZn%X`BQN)EY}pBt5Az0D{6nOR=oTkyc44CBed}tJ&_mro%`d?cPT`c% zjvbJl0tEB;_~;IG-Qf^{$&xd37Qo$fD~)7`+}GH+CKg0aVb!Ea z-F=R?-qcQ0)0Xb;#S0hKV<#=K>b)B%cp85?FUu=uWS$=IM zWZTP|874eSWXn`Ii@HIo_jgv)9>L-+3o|i2e>4?y+v&iy{3VAmaM&gQ>1W|zQx`t8 z;n3&FFl(EeabD(9-9^OnM80^lGn5urYYt)b*X$7b{=>K$yUggCv3U$LuN5$Sz_zPp*;g5_FOaUmXn*G~^ zsKHgcxLX*=7yrY(i}9$=!pQru#q*+ zUlAbh_8YeJ`L5I3@FS(}5FohpEf@2S92xvxr{T7ls2vf#wd97TXMNwphb1>PgA8bI zT-5K`W+y;G7tme*CR`-cz&XpGIE^MX(kT1K(Tl1Hi#4xV?YY*xRepC(@@0h;?QQMl zq2WR#X>)|Po-J>EVadvbEcW$G@4ED-=4PV)|91RNvjT`RFG`3k*Nk@ieJ84J9sQkrRcx93EOJG@lraMGZQtjoVdHANtN}M08=i){ zi@Hs_sA0XcOpA5eHCvvFlG4@kc7w7%h`dRJjA^`O1Lw!(6uBF_c($=Uy*AHrh-D?1G^28F|gFxli@++RzyhCYaiGLym;h_ zDZkq`c1HB&*o}xoRukgSifAJ@Ozzh0cbp@{JN~`4VDUN*yMD@e6E+>^SZ0ang5PmE zL`SZU5nk#7wL`Msp{QuhoF~t8+rDXqnylAURNTs&?E0`qkVby08NZul@hB%JOq)}i z9Te23XtlnxW78(pq|2ozlvd5grsz@M^D5UH2b<33qBbraW~^)68?$QMD~FEw9#p-u zF6&WQjmHDs;ii-0fjI`Do2;K}tzA2iTQYQh2C0Y_P(Ypz-Q3^E%C%M{r`&U|aeIr+ zuO=b~?}X^8xF_JEe#0^?J2MjuM6&&}$JEwD6Lk>wx7)t`lAX#uLgFq|q*CrZ)Nzad zEmg~5cu;w~7?G1=6Jo8B`JD^Up^8RKecbA3@Eyiq&0}qRNE;!^;+fyc&@37&`}%IW z8iR0CHPRSPDzi!I>p*544x}efK0qLg<1KvNzg^srka_U7dqpqh^NZpJ{h*H#VHW@ob;+|G~i|)7O)qnZ2 z37qRR9)X)^NYRm(!_Z+DA*RZ1?`cYcgv+Q%60n{}h|G0k`p=_;@CyT(+#`4j2^m7J zUBruG%3bAU3U8|68xtmX-5yeP5_h4VPsLkHF#(`Eav5x-fhLk@jdU?pWDQyXqh=n3 z?zl)JaRM4$L|lNd=+k}X%_k1{+0WPiqEklf)c5#PF+LaL6CK7qh&D;mskXIl+yj*( z%DrJWM{W0em+LYh(@xuU4+czx86HT_ki@6PS%%XZIPxs}&2VyFMffKDGKF>`gkK@e zEPVU+Z}uM~eEI)#k9xF4bGj|1epG+4+~03f1VwD85`}Ghw~cXr1t}a!n%DvNth?}y z#;u-S9xAwn{$l-Y^o;5+?#*!D>Hexc^G%!TFMbHsRa?;N_6T2khWtFTZfP&W>2FE0 ze{*Bs97Pm}QY&D3)^6j-6z0qQcNsOjj<`bz#$0uIJq-c_ zTb`u~mA8`))?S&z5TeRYUsFH59+nuVPp^mLm#yg;YsB!DqP;TVH*OeXVDpXcPfyDw z#`guR7YLG#5f0S4CdcEaho>vLe^F@hyLS)Z_}tiNOGS!wvqf|lSJw^N+E{XNs%GDb zI3BnF&7hXv65|vBSXhvOZCn4Fu2HZ4&vaFj744pWGDCb)(={xqmUe%R0Qba=fiX+? zx*$;~d;1oausD{W5i>;4vE}Rx^6#Fb?pt8^)C0R&6_p)E~Z7bG4Kx`H!U^kX%m zU)B7Py`b+vjtUKnk#MGSIvoP7Vw4-i&62xP3N5}GXMhjdvoLN4jihLN*P>1lAbrTX zhX)7O4P=7Pfr%taa=LJu^WAA&eJLO546je?gMxyViBVmPe7c*L19Hdf%>RB}1!@GU z4~T_d6d$vLcGJnJ1z7QY-ruAC=0xPRe;XHg6QnIbwifrE$5#{`gR+s3A$<}q-sV$? zz?(Dp-PPGS@ntq>9&}D1#Ns}xJ(yznj_v;^v7UkJkRY+O7bHLc2EiPVBj)po#gdZO z!^7iV{zWSkl}r~vRUkw~0WvCZ8}q13jXWsq)$rgmKzg5N;-uo%FJ=p6D!F%osCS8I6)xM7EDNBYl6obRqR*9j9=6ED)Eh5{swc_z`@EjGL11KiYh_Ff~RQ;@wZ# zNFr+YXV=NhL_&d*g|}KI3us(1sdNFV`^T+(gC>7>qHT+R#}m1Q#o8L%d6ocmwU#b@ zfjOkL1JY*X=f+^Kk$#8%J3V@#0@!Ksj@IG#L9r-AKvM93oquZ<5q0f3RCe#~ z-J!0Ul^ZsoLZ$BT4#;Pz;&13Ku=mXSZV z7h9jGsXD1PL?^BVycbQ%Ca;caKh-cQ06 zYQA&n|0%|gf@K3Q2z|cT!iA`_s!y}3#9aFH=^3^Grs@(xmZ(+Z2$VcIf*_Q;>i<87 rr=!Q|kZ3xPsHPF8Ts|Qgc2o2o-K Date: Sat, 3 Aug 2024 15:16:24 +0800 Subject: [PATCH 3/8] add tests for dependency check --- test/visualization.jl | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/test/visualization.jl b/test/visualization.jl index 590a036..77c5319 100644 --- a/test/visualization.jl +++ b/test/visualization.jl @@ -1,6 +1,31 @@ +using OMEinsumContractionOrders: ein2hypergraph, ein2elimination + +# tests before the extension loaded +@testset "luxor tensor plot dependency check" begin + @test_throws ArgumentError begin + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a']) + ein2hypergraph(eincode) + end + + @test_throws ArgumentError begin + eincode = OMEinsum.rawcode(ein"((ij, jk), kl), lm -> im") + ein2elimination(eincode) + end + + @test_throws ArgumentError begin + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) + viz_eins(eincode) + end + + @test_throws ArgumentError begin + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) + nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) + viz_contraction(nested_code, pathname = "") + end +end + using LuxorGraphPlot using LuxorGraphPlot.Luxor -using OMEinsumContractionOrders: ein2hypergraph, ein2elimination @testset "eincode to hypergraph" begin eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a']) From 5d073b64bcfbd9f7e73d464f086d313628201f61 Mon Sep 17 00:00:00 2001 From: ArrogantGao Date: Sat, 3 Aug 2024 15:29:27 +0800 Subject: [PATCH 4/8] revise error message --- src/visualization.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/visualization.jl b/src/visualization.jl index 6f0e5b7..bcff1a1 100644 --- a/src/visualization.jl +++ b/src/visualization.jl @@ -1,15 +1,15 @@ function ein2hypergraph(args...; kwargs...) - throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`.")) end function ein2elimination(args...; kwargs...) - throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`.")) end function viz_eins(args...; kwargs...) - throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`.")) end function viz_contraction(args...; kwargs...) - throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`")) + throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`.")) end \ No newline at end of file From 4f1cd73df16c3bd1b5b53bfef7b93aaeb2d53b2b Mon Sep 17 00:00:00 2001 From: ArrogantGao Date: Sun, 4 Aug 2024 12:48:55 +0800 Subject: [PATCH 5/8] update the interface of viz contraction --- ext/LuxorTensorPlot/src/viz_contraction.jl | 73 +++++++++++----------- test/visualization.jl | 23 +++++-- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/ext/LuxorTensorPlot/src/viz_contraction.jl b/ext/LuxorTensorPlot/src/viz_contraction.jl index aa03763..0027571 100644 --- a/ext/LuxorTensorPlot/src/viz_contraction.jl +++ b/ext/LuxorTensorPlot/src/viz_contraction.jl @@ -48,36 +48,38 @@ Visualize the contraction process of a tensor network. - `ein::ET`: The tensor network to visualize. - `locs`: The layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`. - `framerate`: The frame rate of the animation. Default is 30. -- `filename`: The base name of the output files. Default is "contraction". -- `pathname`: The directory path to save the output files. Default is the current directory. -- `create_gif`: Whether to create a GIF animation. Default is `false`. -- `create_video`: Whether to create a video. Default is `true`. +- `filename`: The name of the output files. If the filename ends with ".gif", a GIF file will be generated. If the filename ends with ".mp4", a video file will be generated. If the filename contains path, the file will be saved to the path, or the file will be saved to the tempdirectory. Default is "contraction.mp4". - `color`: The color of the contraction lines. Default is `(0.5, 0.5, 0.5, 0.5)`. - `show_progress`: Whether to show progress information. Default is `false`. # Returns -- If `create_gif` is `true`, returns the path to the generated GIF animation. -- If `create_video` is `true`, returns the path to the generated video. +- the path of the GIF or video file. """ function OMEinsumContractionOrders.viz_contraction( ein::ET; locs=StressLayout(), - framerate = 30, - filename = "contraction", - pathname = ".", - create_gif = false, - create_video = true, + framerate = 10, + filename = "contraction.mp4", color = (0.5, 0.5, 0.5, 0.5), show_progress::Bool = false ) where{ET <: Union{NestedEinsum, SlicedEinsum}} + # analyze the output format + paths = splitpath(filename) + file_name = paths[end] + path_name = length(paths) > 1 ? joinpath(paths[1:end-1]...) : "" + format = splitext(file_name)[end] + tempdirectory = mktempdir() + + if format != ".gif" && format != ".mp4" + throw(ArgumentError("Unsupported format $format, only .gif and .mp4 are supported")) + end + + # generate the frames elimination_order = ein2elimination(ein) tng = TensorNetworkGraph(ein2hypergraph(ein)) GViz = GraphViz(tng, locs) - tempdirectory = mktempdir() - # @info("Frames for animation \"$(filename)\" are being stored in directory: \n\t $(tempdirectory)") - filecounter = 1 le = length(elimination_order) @info "Generating frames, $(le + 1) frames in total" @@ -90,36 +92,33 @@ function OMEinsumContractionOrders.viz_contraction( filecounter += 1 end - if create_gif - Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $(tempdirectory)/$(filename).gif`) + if format == ".gif" + Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $(tempdirectory)/$(file_name)`) - if !isempty(pathname) - if !isdir(pathname) - @error "$pathname is not a directory." + if !isempty(path_name) + if !isdir(path_name) + @error "$path_name is not a directory." end - fig_path = joinpath(pathname, "$filename.gif") - mv("$(tempdirectory)/$(filename).gif", fig_path, force = true) - @info("GIF is: $fig_path") - giffn = fig_path + mv("$(tempdirectory)/$(file_name)", filename, force = true) + @info("GIF is: $filename") + giffn = filename else - @info("GIF is: $(tempdirectory)/$(filename).gif") - giffn = tempdirectory * "/" * filename * ".gif" + @info("GIF is: $(tempdirectory)/$(file_name)") + giffn = tempdirectory * "/" * file_name end return giffn - elseif create_video - movieformat = ".mp4" - - if !isempty(pathname) - if !isdir(pathname) - @error "$pathname is not a directory." + elseif format == ".mp4" + if !isempty(path_name) + if !isdir(path_name) + @error "$path_name is not a directory." end - pathname = joinpath(pathname, "$(filename)$(movieformat)") + video_path = filename else - pathname = joinpath("$(tempdirectory)", "$(filename)$(movieformat)") + video_path = joinpath("$(tempdirectory)", "$(file_name)") end - @info "Creating video at: $pathname" + @info "Creating video at: $video_path" FFMPEG.ffmpeg_exe(` -loglevel panic -r $(framerate) @@ -129,10 +128,8 @@ function OMEinsumContractionOrders.viz_contraction( -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" -r $(framerate) -pix_fmt yuv420p - -y $(pathname)`) + -y $(video_path)`) - return pathname - else - return tempdirectory + return video_path end end \ No newline at end of file diff --git a/test/visualization.jl b/test/visualization.jl index 77c5319..bc19acf 100644 --- a/test/visualization.jl +++ b/test/visualization.jl @@ -1,3 +1,4 @@ +using OMEinsum using OMEinsumContractionOrders: ein2hypergraph, ein2elimination # tests before the extension loaded @@ -69,20 +70,30 @@ end @testset "visualize contraction" begin eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) - t_mp4 = viz_contraction(nested_code, pathname = "") + t_mp4 = viz_contraction(nested_code) + t_mp4_2 = viz_contraction(nested_code, filename = "test.mp4") @test typeof(t_mp4) == String - t_gif = viz_contraction(nested_code, pathname = "", create_gif = true) + @test typeof(t_mp4_2) == String + t_gif = viz_contraction(nested_code, filename = "test.gif") @test typeof(t_gif) == String + @test_throws ArgumentError begin + viz_contraction(nested_code, filename = "test.avi") + end + sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA()) - t_mp4 = viz_contraction(sliced_code, pathname = "") + t_mp4 = viz_contraction(sliced_code) + t_mp4_2 = viz_contraction(sliced_code, filename = "test.mp4") @test typeof(t_mp4) == String - t_gif = viz_contraction(sliced_code, pathname = "", create_gif = true) + @test typeof(t_mp4_2) == String + t_gif = viz_contraction(sliced_code, filename = "test.gif") @test typeof(t_gif) == String sliced_code2 = optimize_code(eincode, uniformsize(eincode, 2), TreeSA(nslices = 1)) - t_mp4 = viz_contraction(sliced_code2, pathname = "") + t_mp4 = viz_contraction(sliced_code2) + t_mp4_2 = viz_contraction(sliced_code2, filename = "test.mp4") @test typeof(t_mp4) == String - t_gif = viz_contraction(sliced_code2, pathname = "", create_gif = true) + @test typeof(t_mp4_2) == String + t_gif = viz_contraction(sliced_code2, filename = "test.gif") @test typeof(t_gif) == String end \ No newline at end of file From 932a2017e10b74605106ab49072e4ac90dac3c8d Mon Sep 17 00:00:00 2001 From: ArrogantGao Date: Sun, 4 Aug 2024 13:02:48 +0800 Subject: [PATCH 6/8] update readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 5fb9f7f..dfef18c 100644 --- a/README.md +++ b/README.md @@ -121,10 +121,10 @@ ab, ab -> a └─ e -julia> viz_contraction(nested_eins) -[ Info: Generating frames, 5 frames in total -[ Info: Creating video at: ./contraction.mp4 -"./contraction.mp4" +julia> viz_contraction(nested_code) +[ Info: Generating frames, 7 frames in total +[ Info: Creating video at: /var/folders/3y/xl2h1bxj4ql27p01nl5hrrnc0000gn/T/jl_SiSvrH/contraction.mp4 +"/var/folders/3y/xl2h1bxj4ql27p01nl5hrrnc0000gn/T/jl_SiSvrH/contraction.mp4" ``` The resulting image and video will be saved in the current working directory, and the image is shown below: From 843759ffc5ec712a8b08cd09f7cdd39c882efe9c Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sun, 4 Aug 2024 15:35:32 +0800 Subject: [PATCH 7/8] simplify --- ext/LuxorTensorPlot/src/viz_contraction.jl | 125 ++++++++------------- ext/LuxorTensorPlot/src/viz_eins.jl | 26 +++-- test/visualization.jl | 51 +++++---- 3 files changed, 88 insertions(+), 114 deletions(-) diff --git a/ext/LuxorTensorPlot/src/viz_contraction.jl b/ext/LuxorTensorPlot/src/viz_contraction.jl index 0027571..dcc0d61 100644 --- a/ext/LuxorTensorPlot/src/viz_contraction.jl +++ b/ext/LuxorTensorPlot/src/viz_contraction.jl @@ -1,23 +1,23 @@ -function OMEinsumContractionOrders.ein2elimination(ein::NestedEinsum{T}) where{T} +function OMEinsumContractionOrders.ein2elimination(code::NestedEinsum{T}) where{T} elimination_order = Vector{T}() - _ein2elimination!(ein, elimination_order) + _ein2elimination!(code, elimination_order) return elimination_order end -function OMEinsumContractionOrders.ein2elimination(ein::SlicedEinsum{T, NestedEinsum{T}}) where{T} +function OMEinsumContractionOrders.ein2elimination(code::SlicedEinsum{T, NestedEinsum{T}}) where{T} elimination_order = Vector{T}() - _ein2elimination!(ein.eins, elimination_order) + _ein2elimination!(code.eins, elimination_order) # the slicing indices are eliminated at the end - return vcat(elimination_order, ein.slicing) + return vcat(elimination_order, code.slicing) end -function _ein2elimination!(ein::NestedEinsum{T}, elimination_order::Vector{T}) where{T} - if ein.tensorindex == -1 - for arg in ein.args +function _ein2elimination!(code::NestedEinsum{T}, elimination_order::Vector{T}) where{T} + if code.tensorindex == -1 + for arg in code.args _ein2elimination!(arg, elimination_order) end - iy = unique(vcat(getiyv(ein.eins)...)) - for ix in unique(vcat(getixsv(ein.eins)...)) + iy = unique(vcat(getiyv(code.eins)...)) + for ix in unique(vcat(getixsv(code.eins)...)) if !(ix in iy) && !(ix in elimination_order) push!(elimination_order, ix) end @@ -26,100 +26,63 @@ function _ein2elimination!(ein::NestedEinsum{T}, elimination_order::Vector{T}) w return elimination_order end -function elimination_frame(GViz, tng::TensorNetworkGraph{TG, TL}, elimination_order::Vector{TL}, i::Int; filename = nothing, color = (0.5, 0.5, 0.5, 0.5)) where{TG, TL} - GViz2 = deepcopy(GViz) +function elimination_frame(gviz::GraphViz, tng::TensorNetworkGraph{TG, TL}, elimination_order::Vector{TL}, i::Int; filename = nothing) where{TG, TL} + gviz2 = deepcopy(gviz) for j in 1:i id = _get_key(tng.indices_labels, elimination_order[j]) - GViz2.vertex_colors[id] = color + gviz2.vertex_colors[id] = (0.5, 0.5, 0.5, 0.5) end - return show_graph(GViz2, filename = filename) + return show_graph(gviz2, filename = filename) end -function OMEinsumContractionOrders.viz_contraction(ein::T, args...; kwargs...) where{T <: AbstractEinsum} +function OMEinsumContractionOrders.viz_contraction(code::T, args...; kwargs...) where{T <: AbstractEinsum} throw(ArgumentError("Only NestedEinsum and SlicedEinsum{T, NestedEinsum{T}} have contraction order")) end """ - viz_contraction(ein::ET; locs=StressLayout(), framerate=30, filename="contraction", pathname=".", create_gif=false, create_video=true, color=(0.5, 0.5, 0.5, 0.5), show_progress=false) where {ET <: Union{NestedEinsum, SlicedEinsum}} + viz_contraction(code::Union{NestedEinsum, SlicedEinsum}; locs=StressLayout(), framerate=10, filename=tempname() * ".mp4", show_progress=true) Visualize the contraction process of a tensor network. -# Arguments -- `ein::ET`: The tensor network to visualize. -- `locs`: The layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`. -- `framerate`: The frame rate of the animation. Default is 30. -- `filename`: The name of the output files. If the filename ends with ".gif", a GIF file will be generated. If the filename ends with ".mp4", a video file will be generated. If the filename contains path, the file will be saved to the path, or the file will be saved to the tempdirectory. Default is "contraction.mp4". -- `color`: The color of the contraction lines. Default is `(0.5, 0.5, 0.5, 0.5)`. -- `show_progress`: Whether to show progress information. Default is `false`. +### Arguments +- `code`: The tensor network to visualize. + +### Keyword Arguments +- `locs`: The coordinates or layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`. +- `framerate`: The frame rate of the animation. Default is `10`. +- `filename`: The name of the output file, with `.gif` or `.mp4` extension. Default is a temporary file with `.mp4` extension. +- `show_progress`: Whether to show progress information. Default is `true`. # Returns -- the path of the GIF or video file. +- the path of the generated file. """ function OMEinsumContractionOrders.viz_contraction( - ein::ET; - locs=StressLayout(), - framerate = 10, - filename = "contraction.mp4", - color = (0.5, 0.5, 0.5, 0.5), - show_progress::Bool = false - ) where{ET <: Union{NestedEinsum, SlicedEinsum}} + code::Union{NestedEinsum, SlicedEinsum}; + locs=StressLayout(), + framerate = 10, + filename::String = tempname() * ".mp4", + show_progress::Bool = true) # analyze the output format - paths = splitpath(filename) - file_name = paths[end] - path_name = length(paths) > 1 ? joinpath(paths[1:end-1]...) : "" - format = splitext(file_name)[end] + @assert endswith(filename, ".gif") || endswith(filename, ".mp4") "Unsupported file format: $filename, only :gif and :mp4 are supported" tempdirectory = mktempdir() - if format != ".gif" && format != ".mp4" - throw(ArgumentError("Unsupported format $format, only .gif and .mp4 are supported")) - end - # generate the frames - elimination_order = ein2elimination(ein) - tng = TensorNetworkGraph(ein2hypergraph(ein)) - GViz = GraphViz(tng, locs) + elimination_order = ein2elimination(code) + tng = TensorNetworkGraph(ein2hypergraph(code)) + gviz = GraphViz(tng, locs) - filecounter = 1 le = length(elimination_order) - @info "Generating frames, $(le + 1) frames in total" for i in 0:le - if show_progress - @info "Frame $(i + 1) of $(le + 1)" - end - fig_name = "$(tempdirectory)/$(lpad(filecounter, 10, "0")).png" - elimination_frame(GViz, tng, elimination_order, i; filename = fig_name, color = color) - filecounter += 1 + show_progress && @info "Frame $(i + 1) of $(le + 1)" + fig_name = "$(tempdirectory)/$(lpad(i+1, 10, "0")).png" + elimination_frame(gviz, tng, elimination_order, i; filename = fig_name) end - if format == ".gif" - Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $(tempdirectory)/$(file_name)`) - - if !isempty(path_name) - if !isdir(path_name) - @error "$path_name is not a directory." - end - mv("$(tempdirectory)/$(file_name)", filename, force = true) - @info("GIF is: $filename") - giffn = filename - else - @info("GIF is: $(tempdirectory)/$(file_name)") - giffn = tempdirectory * "/" * file_name - end - - return giffn - elseif format == ".mp4" - if !isempty(path_name) - if !isdir(path_name) - @error "$path_name is not a directory." - end - video_path = filename - else - video_path = joinpath("$(tempdirectory)", "$(file_name)") - end - - @info "Creating video at: $video_path" - FFMPEG.ffmpeg_exe(` + if endswith(filename, ".gif") + Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $filename`) + else + Luxor.FFMPEG.ffmpeg_exe(` -loglevel panic -r $(framerate) -f image2 @@ -128,8 +91,8 @@ function OMEinsumContractionOrders.viz_contraction( -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" -r $(framerate) -pix_fmt yuv420p - -y $(video_path)`) - - return video_path + -y $filename`) end + show_progress && @info "Generated output at: $filename" + return filename end \ No newline at end of file diff --git a/ext/LuxorTensorPlot/src/viz_eins.jl b/ext/LuxorTensorPlot/src/viz_eins.jl index 58b6859..43c6ff6 100644 --- a/ext/LuxorTensorPlot/src/viz_eins.jl +++ b/ext/LuxorTensorPlot/src/viz_eins.jl @@ -43,9 +43,9 @@ function _get_key(dict::Dict, value) @error "Value not found in dictionary" end -function OMEinsumContractionOrders.ein2hypergraph(ec::T) where{T <: AbstractEinsum} - ixs = getixsv(ec) - iy = getiyv(ec) +function OMEinsumContractionOrders.ein2hypergraph(code::T) where{T <: AbstractEinsum} + ixs = getixsv(code) + iy = getiyv(code) edges = unique!([Iterators.flatten(ixs)...]) open_edges = [iy[i] for i in 1:length(iy) if iy[i] in edges] @@ -62,19 +62,21 @@ function OMEinsumContractionOrders.ein2hypergraph(ec::T) where{T <: AbstractEins end """ - viz_eins(ec::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...) + viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...) Visualizes an `AbstractEinsum` object by creating a tensor network graph and rendering it using GraphViz. -## Arguments -- `ec::AbstractEinsum`: The `AbstractEinsum` object to visualize. -- `locs=StressLayout()`: The layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`. -- `filename = nothing`: The name of the file to save the visualization to. If `nothing`, the visualization will be displayed on the screen instead of saving to a file. -- `kwargs...`: Additional keyword arguments to be passed to the `GraphViz` constructor. +### Arguments +- `code::AbstractEinsum`: The `AbstractEinsum` object to visualize. +### Keyword Arguments +- `locs=StressLayout()`: The coordinates or layout algorithm to use for positioning the nodes in the graph. +- `filename = nothing`: The name of the file to save the visualization to. If `nothing`, the visualization will be displayed on the screen instead of saving to a file. +- `config = GraphDisplayConfig()`: The configuration for displaying the graph. Please refer to the documentation of [`GraphDisplayConfig`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphDisplayConfig) for more information. +- `kwargs...`: Additional keyword arguments to be passed to the [`GraphViz`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphViz) constructor. """ -function OMEinsumContractionOrders.viz_eins(ec::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...) - tng = TensorNetworkGraph(ein2hypergraph(ec)) +function OMEinsumContractionOrders.viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, config=LuxorTensorPlot.GraphDisplayConfig(), kwargs...) + tng = TensorNetworkGraph(ein2hypergraph(code)) gviz = GraphViz(tng, locs; kwargs...) - return show_graph(gviz, filename = filename) + return show_graph(gviz; filename, config) end \ No newline at end of file diff --git a/test/visualization.jl b/test/visualization.jl index bc19acf..eda62aa 100644 --- a/test/visualization.jl +++ b/test/visualization.jl @@ -1,5 +1,6 @@ using OMEinsum using OMEinsumContractionOrders: ein2hypergraph, ein2elimination +using Test, OMEinsumContractionOrders # tests before the extension loaded @testset "luxor tensor plot dependency check" begin @@ -52,48 +53,56 @@ end @testset "visualize eincode" begin eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) t = viz_eins(eincode) - @test typeof(t) == Luxor.Drawing + @test t isa Luxor.Drawing nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) t = viz_eins(nested_code) - @test typeof(t) == Luxor.Drawing + @test t isa Luxor.Drawing sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA()) t = viz_eins(sliced_code) - @test typeof(t) == Luxor.Drawing + @test t isa Luxor.Drawing open_eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a']) t = viz_eins(open_eincode) - @test typeof(t) == Luxor.Drawing + @test t isa Luxor.Drawing + + # filename and location specified + eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) + filename = tempname() * ".png" + viz_eins(eincode; filename, locs=vcat([(randn() * 60, 0.0) for i=1:5], [(randn() * 60, 320.0) for i=1:6])) + @test isfile(filename) end @testset "visualize contraction" begin eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}()) nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) t_mp4 = viz_contraction(nested_code) - t_mp4_2 = viz_contraction(nested_code, filename = "test.mp4") - @test typeof(t_mp4) == String - @test typeof(t_mp4_2) == String - t_gif = viz_contraction(nested_code, filename = "test.gif") - @test typeof(t_gif) == String - - @test_throws ArgumentError begin + tempmp4 = tempname() * ".mp4" + tempgif = tempname() * ".gif" + t_mp4_2 = viz_contraction(nested_code, filename = tempmp4) + @test t_mp4 isa String + @test t_mp4_2 isa String + t_gif = viz_contraction(nested_code, filename = tempgif) + @test t_gif isa String + + @test_throws AssertionError begin viz_contraction(nested_code, filename = "test.avi") end sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA()) t_mp4 = viz_contraction(sliced_code) - t_mp4_2 = viz_contraction(sliced_code, filename = "test.mp4") - @test typeof(t_mp4) == String - @test typeof(t_mp4_2) == String - t_gif = viz_contraction(sliced_code, filename = "test.gif") - @test typeof(t_gif) == String + t_mp4_2 = viz_contraction(sliced_code, filename = tempmp4) + @test t_mp4 isa String + @test t_mp4_2 isa String + t_gif = viz_contraction(sliced_code, filename = tempgif) + @test t_gif isa String sliced_code2 = optimize_code(eincode, uniformsize(eincode, 2), TreeSA(nslices = 1)) t_mp4 = viz_contraction(sliced_code2) - t_mp4_2 = viz_contraction(sliced_code2, filename = "test.mp4") - @test typeof(t_mp4) == String - @test typeof(t_mp4_2) == String - t_gif = viz_contraction(sliced_code2, filename = "test.gif") - @test typeof(t_gif) == String + t_mp4_2 = viz_contraction(sliced_code2, filename = tempmp4) + @test t_mp4 isa String + @test t_mp4_2 isa String + t_gif = viz_contraction(sliced_code2, filename = tempgif) + @test t_gif isa String end \ No newline at end of file From f4239607ab54b875dc6da2ce048b33dea4f1483b Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sun, 4 Aug 2024 15:39:48 +0800 Subject: [PATCH 8/8] fix file path --- ext/LuxorTensorPlot/src/viz_contraction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LuxorTensorPlot/src/viz_contraction.jl b/ext/LuxorTensorPlot/src/viz_contraction.jl index dcc0d61..12af417 100644 --- a/ext/LuxorTensorPlot/src/viz_contraction.jl +++ b/ext/LuxorTensorPlot/src/viz_contraction.jl @@ -75,7 +75,7 @@ function OMEinsumContractionOrders.viz_contraction( le = length(elimination_order) for i in 0:le show_progress && @info "Frame $(i + 1) of $(le + 1)" - fig_name = "$(tempdirectory)/$(lpad(i+1, 10, "0")).png" + fig_name = joinpath(tempdirectory, "$(lpad(i+1, 10, "0")).png") elimination_frame(gviz, tng, elimination_order, i; filename = fig_name) end