Skip to content

Commit

Permalink
update the interface of viz contraction
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrogantGao committed Aug 4, 2024
1 parent 0638365 commit 4f1cd73
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 44 deletions.
73 changes: 35 additions & 38 deletions ext/LuxorTensorPlot/src/viz_contraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,38 @@ Visualize the contraction process of a tensor network.
- `ein::ET`: The tensor network to visualize.
- `locs`: The layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`.
- `framerate`: The frame rate of the animation. Default is 30.
- `filename`: The base name of the output files. Default is "contraction".
- `pathname`: The directory path to save the output files. Default is the current directory.
- `create_gif`: Whether to create a GIF animation. Default is `false`.
- `create_video`: Whether to create a video. Default is `true`.
- `filename`: The name of the output files. If the filename ends with ".gif", a GIF file will be generated. If the filename ends with ".mp4", a video file will be generated. If the filename contains path, the file will be saved to the path, or the file will be saved to the tempdirectory. Default is "contraction.mp4".
- `color`: The color of the contraction lines. Default is `(0.5, 0.5, 0.5, 0.5)`.
- `show_progress`: Whether to show progress information. Default is `false`.
# Returns
- If `create_gif` is `true`, returns the path to the generated GIF animation.
- If `create_video` is `true`, returns the path to the generated video.
- the path of the GIF or video file.
"""
function OMEinsumContractionOrders.viz_contraction(
ein::ET;
locs=StressLayout(),
framerate = 30,
filename = "contraction",
pathname = ".",
create_gif = false,
create_video = true,
framerate = 10,
filename = "contraction.mp4",
color = (0.5, 0.5, 0.5, 0.5),
show_progress::Bool = false
) where{ET <: Union{NestedEinsum, SlicedEinsum}}

# analyze the output format
paths = splitpath(filename)
file_name = paths[end]
path_name = length(paths) > 1 ? joinpath(paths[1:end-1]...) : ""
format = splitext(file_name)[end]
tempdirectory = mktempdir()

if format != ".gif" && format != ".mp4"
throw(ArgumentError("Unsupported format $format, only .gif and .mp4 are supported"))
end

# generate the frames
elimination_order = ein2elimination(ein)
tng = TensorNetworkGraph(ein2hypergraph(ein))
GViz = GraphViz(tng, locs)

tempdirectory = mktempdir()
# @info("Frames for animation \"$(filename)\" are being stored in directory: \n\t $(tempdirectory)")

filecounter = 1
le = length(elimination_order)
@info "Generating frames, $(le + 1) frames in total"
Expand All @@ -90,36 +92,33 @@ function OMEinsumContractionOrders.viz_contraction(
filecounter += 1
end

if create_gif
Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $(tempdirectory)/$(filename).gif`)
if format == ".gif"
Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $(tempdirectory)/$(file_name)`)

if !isempty(pathname)
if !isdir(pathname)
@error "$pathname is not a directory."
if !isempty(path_name)
if !isdir(path_name)
@error "$path_name is not a directory."
end
fig_path = joinpath(pathname, "$filename.gif")
mv("$(tempdirectory)/$(filename).gif", fig_path, force = true)
@info("GIF is: $fig_path")
giffn = fig_path
mv("$(tempdirectory)/$(file_name)", filename, force = true)
@info("GIF is: $filename")
giffn = filename
else
@info("GIF is: $(tempdirectory)/$(filename).gif")
giffn = tempdirectory * "/" * filename * ".gif"
@info("GIF is: $(tempdirectory)/$(file_name)")
giffn = tempdirectory * "/" * file_name
end

return giffn
elseif create_video
movieformat = ".mp4"

if !isempty(pathname)
if !isdir(pathname)
@error "$pathname is not a directory."
elseif format == ".mp4"
if !isempty(path_name)
if !isdir(path_name)
@error "$path_name is not a directory."
end
pathname = joinpath(pathname, "$(filename)$(movieformat)")
video_path = filename
else
pathname = joinpath("$(tempdirectory)", "$(filename)$(movieformat)")
video_path = joinpath("$(tempdirectory)", "$(file_name)")
end

@info "Creating video at: $pathname"
@info "Creating video at: $video_path"
FFMPEG.ffmpeg_exe(`
-loglevel panic
-r $(framerate)
Expand All @@ -129,10 +128,8 @@ function OMEinsumContractionOrders.viz_contraction(
-vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"
-r $(framerate)
-pix_fmt yuv420p
-y $(pathname)`)
-y $(video_path)`)

return pathname
else
return tempdirectory
return video_path
end
end
23 changes: 17 additions & 6 deletions test/visualization.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using OMEinsum
using OMEinsumContractionOrders: ein2hypergraph, ein2elimination

# tests before the extension loaded
Expand Down Expand Up @@ -69,20 +70,30 @@ end
@testset "visualize contraction" begin
eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}())
nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
t_mp4 = viz_contraction(nested_code, pathname = "")
t_mp4 = viz_contraction(nested_code)
t_mp4_2 = viz_contraction(nested_code, filename = "test.mp4")
@test typeof(t_mp4) == String
t_gif = viz_contraction(nested_code, pathname = "", create_gif = true)
@test typeof(t_mp4_2) == String
t_gif = viz_contraction(nested_code, filename = "test.gif")
@test typeof(t_gif) == String

@test_throws ArgumentError begin
viz_contraction(nested_code, filename = "test.avi")
end

sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA())
t_mp4 = viz_contraction(sliced_code, pathname = "")
t_mp4 = viz_contraction(sliced_code)
t_mp4_2 = viz_contraction(sliced_code, filename = "test.mp4")
@test typeof(t_mp4) == String
t_gif = viz_contraction(sliced_code, pathname = "", create_gif = true)
@test typeof(t_mp4_2) == String
t_gif = viz_contraction(sliced_code, filename = "test.gif")
@test typeof(t_gif) == String

sliced_code2 = optimize_code(eincode, uniformsize(eincode, 2), TreeSA(nslices = 1))
t_mp4 = viz_contraction(sliced_code2, pathname = "")
t_mp4 = viz_contraction(sliced_code2)
t_mp4_2 = viz_contraction(sliced_code2, filename = "test.mp4")
@test typeof(t_mp4) == String
t_gif = viz_contraction(sliced_code2, pathname = "", create_gif = true)
@test typeof(t_mp4_2) == String
t_gif = viz_contraction(sliced_code2, filename = "test.gif")
@test typeof(t_gif) == String
end

0 comments on commit 4f1cd73

Please sign in to comment.