From 43c8b4d4c4bc5d5b212c08090d20f5b92f76918f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 14 Jun 2024 16:39:36 -0400 Subject: [PATCH] fix tensor_product(::dual) --- NDTensors/src/lib/GradedAxes/src/fusion.jl | 10 ++-- NDTensors/src/lib/GradedAxes/test/runtests.jl | 2 +- .../GradedAxes/test/test_tensor_product.jl | 49 ++++++++++++++++--- .../src/lib/Sectors/test/test_fusion_rules.jl | 8 +-- 4 files changed, 51 insertions(+), 18 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index fdd6d03307..2e5513c7ee 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -44,15 +44,15 @@ end # Handle dual. Always return a non-dual GradedUnitRange. function tensor_product(a1::AbstractUnitRange, a2::UnitRangeDual) - return tensor_product(a1, flip(dual(a2))) + return tensor_product(a1, flip(a2)) end function tensor_product(a1::UnitRangeDual, a2::AbstractUnitRange) - return tensor_product(flip(dual(a1)), a2) + return tensor_product(flip(a1), a2) end function tensor_product(a1::UnitRangeDual, a2::UnitRangeDual) - return tensor_product(flip(dual(a1)), flip(dual(a2))) + return tensor_product(flip(a1), flip(a2)) end function fuse_labels(x, y) @@ -127,7 +127,7 @@ function blockmergesort(g::AbstractGradedUnitRange) return GradedAxes.gradedrange(new_blocklengths) end -blockmergesort(g::UnitRangeDual) = dual(blockmergesort(nondual(g))) +blockmergesort(g::UnitRangeDual) = dual(blockmergesort(flip(g))) blockmergesort(g::OneToOne) = g # fusion_product produces a sorted, non-dual GradedUnitRange @@ -136,7 +136,7 @@ function fusion_product(g1, g2) end fusion_product(g::AbstractUnitRange) = blockmergesort(g) -fusion_product(g::UnitRangeDual) = fusion_product(flip(nondual(g))) +fusion_product(g::UnitRangeDual) = fusion_product(flip(g)) # recursive fusion_product. Simpler than reduce + fix type stability issues with reduce function fusion_product(g1, g2, g3...) diff --git a/NDTensors/src/lib/GradedAxes/test/runtests.jl b/NDTensors/src/lib/GradedAxes/test/runtests.jl index 09335af5e8..c0fdca21be 100644 --- a/NDTensors/src/lib/GradedAxes/test/runtests.jl +++ b/NDTensors/src/lib/GradedAxes/test/runtests.jl @@ -2,7 +2,7 @@ using Test: @testset @testset "GradedAxes" begin include("test_basics.jl") - include("test_tensor_product.jl") include("test_dual.jl") + include("test_tensor_product.jl") end end diff --git a/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl b/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl index d99091508b..7b533f79c5 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl @@ -1,14 +1,26 @@ @eval module $(gensym()) +using Test: @test, @testset + +using BlockArrays: blocklength, blocklengths + using NDTensors.GradedAxes: GradedAxes, GradedOneTo, OneToOne, + dual, fusion_product, + flip, gradedrange, gradedisequal, + isdual, tensor_product -using BlockArrays: blocklength, blocklengths -using Test: @test, @testset + +struct U1 + n::Int +end +GradedAxes.dual(c::U1) = U1(-c.n) +Base.isless(c1::U1, c2::U1) = c1.n < c2.n +GradedAxes.fuse_labels(x::U1, y::U1) = U1(x.n + y.n) @testset "GradedAxes.tensor_product" begin GradedAxes.fuse_labels(x::String, y::String) = x * y @@ -31,20 +43,41 @@ using Test: @test, @testset end @testset "GradedAxes.fusion_product" begin - GradedAxes.fuse_labels(i::Int, j::Int) = i + j - g0 = OneToOne() @test gradedisequal(fusion_product(g0, g0), g0) - a = gradedrange([1 => 1, 2 => 3, 1 => 1]) + a = gradedrange([U1(1) => 1, U1(2) => 3, U1(1) => 1]) b = fusion_product(a) - @test gradedisequal(b, gradedrange([1 => 2, 2 => 3])) + @test gradedisequal(b, gradedrange([U1(1) => 2, U1(2) => 3])) c = fusion_product(a, a) - @test gradedisequal(c, gradedrange([2 => 4, 3 => 12, 4 => 9])) + @test gradedisequal(c, gradedrange([U1(2) => 4, U1(3) => 12, U1(4) => 9])) d = fusion_product(a, a, a) - @test gradedisequal(d, gradedrange([3 => 8, 4 => 36, 5 => 54, 6 => 27])) + @test gradedisequal(d, gradedrange([U1(3) => 8, U1(4) => 36, U1(5) => 54, U1(6) => 27])) +end + +@testset "dual and tensor_product" begin + a = gradedrange([U1(1) => 1, U1(2) => 3, U1(1) => 1]) + ad = dual(a) + + b = fusion_product(ad) + @test b isa GradedOneTo + @test !isdual(b) + @test gradedisequal(b, gradedrange([U1(-2) => 3, U1(-1) => 2])) + + c = fusion_product(ad, ad) + @test c isa GradedOneTo + @test !isdual(c) + @test gradedisequal(c, gradedrange([U1(-4) => 9, U1(-3) => 12, U1(-2) => 4])) + + d = fusion_product(ad, a) + @test !isdual(d) + @test gradedisequal(d, gradedrange([U1(-1) => 6, U1(0) => 13, U1(1) => 6])) + + e = fusion_product(a, ad) + @test !isdual(d) + @test gradedisequal(e, d) end end diff --git a/NDTensors/src/lib/Sectors/test/test_fusion_rules.jl b/NDTensors/src/lib/Sectors/test/test_fusion_rules.jl index 6e759dc347..c689bc7c3e 100644 --- a/NDTensors/src/lib/Sectors/test/test_fusion_rules.jl +++ b/NDTensors/src/lib/Sectors/test/test_fusion_rules.jl @@ -1,6 +1,6 @@ @eval module $(gensym()) using NDTensors.GradedAxes: - dual, fusion_product, gradedisequal, gradedrange, label_dual, tensor_product + dual, fusion_product, gradedisequal, gradedrange, flip, tensor_product using NDTensors.Sectors: ⊗, Fib, Ising, SU, SU2, U1, Z, block_dimensions, quantum_dimension, trivial using Test: @inferred, @test, @testset, @test_throws @@ -96,7 +96,7 @@ end g1 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 2]) g2 = gradedrange([U1(-2) => 2, U1(0) => 1, U1(1) => 2]) - @test gradedisequal(label_dual(g1), gradedrange([U1(1) => 1, U1(0) => 1, U1(-1) => 2])) + @test gradedisequal(flip(dual(g1)), gradedrange([U1(1) => 1, U1(0) => 1, U1(-1) => 2])) @test (@inferred block_dimensions(g1)) == [1, 1, 2] gt = gradedrange([ @@ -191,7 +191,7 @@ end @test gradedisequal(tensor_product(g3, g4), g34) - @test gradedisequal(label_dual(g3), g3) # trivial for SU(2) + @test gradedisequal(dual(flip(g3)), g3) # trivial for SU(2) @test gradedisequal( (@inferred fusion_product(g3, g4)), gradedrange([SU2(0) => 4, SU2(1//2) => 6, SU2(1) => 6, SU2(3//2) => 5, SU2(2) => 2]), @@ -206,7 +206,7 @@ end g5 = gradedrange([s1 => 1, f3 => 1]) g6 = gradedrange([s1 => 1, c3 => 1]) - @test gradedisequal(label_dual(g5), g6) + @test gradedisequal(dual(flip(g5)), g6) @test gradedisequal( fusion_product(g5, g6), gradedrange([s1 => 2, f3 => 1, c3 => 1, ad8 => 1]) )