Skip to content

Commit

Permalink
fix tensor_product(::dual)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Jun 14, 2024
1 parent 45e8cc6 commit 43c8b4d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 18 deletions.
10 changes: 5 additions & 5 deletions NDTensors/src/lib/GradedAxes/src/fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ end

# Handle dual. Always return a non-dual GradedUnitRange.
function tensor_product(a1::AbstractUnitRange, a2::UnitRangeDual)
return tensor_product(a1, flip(dual(a2)))
return tensor_product(a1, flip(a2))
end

function tensor_product(a1::UnitRangeDual, a2::AbstractUnitRange)
return tensor_product(flip(dual(a1)), a2)
return tensor_product(flip(a1), a2)
end

function tensor_product(a1::UnitRangeDual, a2::UnitRangeDual)
return tensor_product(flip(dual(a1)), flip(dual(a2)))
return tensor_product(flip(a1), flip(a2))
end

function fuse_labels(x, y)
Expand Down Expand Up @@ -127,7 +127,7 @@ function blockmergesort(g::AbstractGradedUnitRange)
return GradedAxes.gradedrange(new_blocklengths)
end

blockmergesort(g::UnitRangeDual) = dual(blockmergesort(nondual(g)))
blockmergesort(g::UnitRangeDual) = dual(blockmergesort(flip(g)))
blockmergesort(g::OneToOne) = g

# fusion_product produces a sorted, non-dual GradedUnitRange
Expand All @@ -136,7 +136,7 @@ function fusion_product(g1, g2)
end

fusion_product(g::AbstractUnitRange) = blockmergesort(g)
fusion_product(g::UnitRangeDual) = fusion_product(flip(nondual(g)))
fusion_product(g::UnitRangeDual) = fusion_product(flip(g))

# recursive fusion_product. Simpler than reduce + fix type stability issues with reduce
function fusion_product(g1, g2, g3...)
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/lib/GradedAxes/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using Test: @testset
@testset "GradedAxes" begin
include("test_basics.jl")
include("test_tensor_product.jl")
include("test_dual.jl")
include("test_tensor_product.jl")
end
end
49 changes: 41 additions & 8 deletions NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
@eval module $(gensym())
using Test: @test, @testset

using BlockArrays: blocklength, blocklengths

using NDTensors.GradedAxes:
GradedAxes,
GradedOneTo,
OneToOne,
dual,
fusion_product,
flip,
gradedrange,
gradedisequal,
isdual,
tensor_product
using BlockArrays: blocklength, blocklengths
using Test: @test, @testset

struct U1
n::Int
end
GradedAxes.dual(c::U1) = U1(-c.n)
Base.isless(c1::U1, c2::U1) = c1.n < c2.n
GradedAxes.fuse_labels(x::U1, y::U1) = U1(x.n + y.n)

@testset "GradedAxes.tensor_product" begin
GradedAxes.fuse_labels(x::String, y::String) = x * y
Expand All @@ -31,20 +43,41 @@ using Test: @test, @testset
end

@testset "GradedAxes.fusion_product" begin
GradedAxes.fuse_labels(i::Int, j::Int) = i + j

g0 = OneToOne()
@test gradedisequal(fusion_product(g0, g0), g0)

a = gradedrange([1 => 1, 2 => 3, 1 => 1])
a = gradedrange([U1(1) => 1, U1(2) => 3, U1(1) => 1])

b = fusion_product(a)
@test gradedisequal(b, gradedrange([1 => 2, 2 => 3]))
@test gradedisequal(b, gradedrange([U1(1) => 2, U1(2) => 3]))

c = fusion_product(a, a)
@test gradedisequal(c, gradedrange([2 => 4, 3 => 12, 4 => 9]))
@test gradedisequal(c, gradedrange([U1(2) => 4, U1(3) => 12, U1(4) => 9]))

d = fusion_product(a, a, a)
@test gradedisequal(d, gradedrange([3 => 8, 4 => 36, 5 => 54, 6 => 27]))
@test gradedisequal(d, gradedrange([U1(3) => 8, U1(4) => 36, U1(5) => 54, U1(6) => 27]))
end

@testset "dual and tensor_product" begin
a = gradedrange([U1(1) => 1, U1(2) => 3, U1(1) => 1])
ad = dual(a)

b = fusion_product(ad)
@test b isa GradedOneTo
@test !isdual(b)
@test gradedisequal(b, gradedrange([U1(-2) => 3, U1(-1) => 2]))

c = fusion_product(ad, ad)
@test c isa GradedOneTo
@test !isdual(c)
@test gradedisequal(c, gradedrange([U1(-4) => 9, U1(-3) => 12, U1(-2) => 4]))

d = fusion_product(ad, a)
@test !isdual(d)
@test gradedisequal(d, gradedrange([U1(-1) => 6, U1(0) => 13, U1(1) => 6]))

e = fusion_product(a, ad)
@test !isdual(d)
@test gradedisequal(e, d)
end
end
8 changes: 4 additions & 4 deletions NDTensors/src/lib/Sectors/test/test_fusion_rules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@eval module $(gensym())
using NDTensors.GradedAxes:
dual, fusion_product, gradedisequal, gradedrange, label_dual, tensor_product
dual, fusion_product, gradedisequal, gradedrange, flip, tensor_product
using NDTensors.Sectors:
, Fib, Ising, SU, SU2, U1, Z, block_dimensions, quantum_dimension, trivial
using Test: @inferred, @test, @testset, @test_throws
Expand Down Expand Up @@ -96,7 +96,7 @@ end
g1 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 2])
g2 = gradedrange([U1(-2) => 2, U1(0) => 1, U1(1) => 2])

@test gradedisequal(label_dual(g1), gradedrange([U1(1) => 1, U1(0) => 1, U1(-1) => 2]))
@test gradedisequal(flip(dual(g1)), gradedrange([U1(1) => 1, U1(0) => 1, U1(-1) => 2]))
@test (@inferred block_dimensions(g1)) == [1, 1, 2]

gt = gradedrange([
Expand Down Expand Up @@ -191,7 +191,7 @@ end

@test gradedisequal(tensor_product(g3, g4), g34)

@test gradedisequal(label_dual(g3), g3) # trivial for SU(2)
@test gradedisequal(dual(flip(g3)), g3) # trivial for SU(2)
@test gradedisequal(
(@inferred fusion_product(g3, g4)),
gradedrange([SU2(0) => 4, SU2(1//2) => 6, SU2(1) => 6, SU2(3//2) => 5, SU2(2) => 2]),
Expand All @@ -206,7 +206,7 @@ end

g5 = gradedrange([s1 => 1, f3 => 1])
g6 = gradedrange([s1 => 1, c3 => 1])
@test gradedisequal(label_dual(g5), g6)
@test gradedisequal(dual(flip(g5)), g6)
@test gradedisequal(
fusion_product(g5, g6), gradedrange([s1 => 2, f3 => 1, c3 => 1, ad8 => 1])
)
Expand Down

0 comments on commit 43c8b4d

Please sign in to comment.