Skip to content

Commit

Permalink
clean categories_fusion_rule
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Jun 4, 2024
1 parent 9c968b0 commit bf31a2e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 44 deletions.
34 changes: 17 additions & 17 deletions NDTensors/src/lib/Sectors/src/category_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,26 +219,26 @@ end
function categories_fusion_rule(cats1::NamedTuple, cats2::NamedTuple)
diff_cat = CategoryProduct(symdiff_keys(cats1, cats2))
nt1 = intersect_keys(cats1, cats2)
shared1 = ntuple(i -> (; keys(nt1)[i] => values(nt1)[i]), length(nt1))
nt2 = intersect_keys(cats2, cats1)
shared2 = ntuple(i -> (; keys(nt2)[i] => values(nt2)[i]), length(nt2))
return diff_cat × categories_fusion_rule(shared1, shared2)
fused = map(fusion_rule, values(nt1), values(nt2))
return diff_cat × recover_key(typeof(nt1), fused)
end

# abelian fusion of one category
function fusion_rule(::AbelianGroup, cats1::NT, cats2::NT) where {NT<:NamedTuple}
fused = fusion_rule(only(values(cats1)), only(values(cats2)))
return sector(only(keys(cats1)) => fused)
function recover_key(NT::Type, fused::Tuple{Vararg{<:AbstractCategory}})
return sector(NT, fused)
end

# generic fusion of one category
function fusion_rule(::SymmetryStyle, cats1::NT, cats2::NT) where {NT<:NamedTuple}
fused = fusion_rule(only(values(cats1)), only(values(cats2)))
key = only(keys(cats1))
v = Vector{Pair{CategoryProduct{NT},Int64}}()
for la in BlockArrays.blocklengths(fused)
push!(v, sector(key => LabelledNumbers.label(la)) => LabelledNumbers.unlabel(la))
end
g = GradedAxes.gradedrange(v)
return g
function recover_key(NT::Type, fused::AbstractCategory)
return recover_key(NT, (fused,))
end

function recover_key(NT::Type, fused::CategoryProduct)
return recover_key(NT, categories(fused))
end

function recover_key(NT::Type, fused::Tuple)
g0 = reduce(×, fused)
blocklabels_key = recover_key.(NT, GradedAxes.blocklabels(g0))
pairs_key = blocklabels_key .=> LabelledNumbers.unlabel.(BlockArrays.blocklengths(g0))
return GradedAxes.gradedrange(pairs_key)
end
58 changes: 31 additions & 27 deletions NDTensors/src/lib/Sectors/test/test_category_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,17 @@ end
@testset "Fusion of different length Categories" begin
@test sector(U1(1) × U1(0)) sector(U1(1)) == sector(U1(2) × U1(0))
@test gradedisequal(
sector(SU2(0) × SU2(0)) sector(SU2(1)), gradedrange([sector(SU2(1) × SU2(0)) => 1])
(@inferred sector(SU2(0) × SU2(0)) sector(SU2(1))),
gradedrange([sector(SU2(1) × SU2(0)) => 1]),
)

@test gradedisequal(
sector(SU2(1) × U1(1)) sector(SU2(0)), gradedrange([sector(SU2(1) × U1(1)) => 1])
(@inferred sector(SU2(1) × U1(1)) sector(SU2(0))),
gradedrange([sector(SU2(1) × U1(1)) => 1]),
)
@test gradedisequal(
sector(U1(1) × SU2(1)) sector(U1(2)), gradedrange([sector(U1(3) × SU2(1)) => 1])
(@inferred sector(U1(1) × SU2(1)) sector(U1(2))),
gradedrange([sector(U1(3) × SU2(1)) => 1]),
)

# check incompatible categories
Expand Down Expand Up @@ -371,19 +374,19 @@ end
q01 = sector(; B=U1(1))
q11 = sector(; A=U1(1), B=U1(1))

@test q10 q10 == sector(; A=U1(2))
@test (@inferred q10 q10) == sector(; A=U1(2))
@test (@inferred q01 q00) == q01
@test (@inferred q00 q01) == q01
@test (@inferred q10 q01) == q11
@test q11 q11 == sector(; A=U1(2), B=U1(2))
@test (@inferred q11 q11) == sector(; A=U1(2), B=U1(2))

s11 = sector(; A=U1(1), B=Z{2}(1))
s10 = sector(; A=U1(1))
s01 = sector(; B=Z{2}(1))
@test (@inferred s01 q00) == s01
@test (@inferred q00 s01) == s01
@test (@inferred s10 s01) == s11
@test s11 s11 == sector(; A=U1(2), B=Z{2}(0))
@test (@inferred s11 s11) == sector(; A=U1(2), B=Z{2}(0))
end

@testset "Fusion of NonAbelian products" begin
Expand All @@ -393,14 +396,14 @@ end
phab = sector(; A=SU2(1//2), B=SU2(1//2))

@test gradedisequal(
pha pha, gradedrange([sector(; A=SU2(0)) => 1, sector(; A=SU2(1)) => 1])
(@inferred pha pha), gradedrange([sector(; A=SU2(0)) => 1, sector(; A=SU2(1)) => 1])
)
@test gradedisequal((@inferred pha p0), gradedrange([pha => 1]))
@test gradedisequal((@inferred p0 phb), gradedrange([phb => 1]))
@test gradedisequal((@inferred pha phb), gradedrange([phab => 1]))

@test gradedisequal(
phab phab,
(@inferred phab phab),
gradedrange([
sector(; A=SU2(0), B=SU2(0)) => 1,
sector(; A=SU2(1), B=SU2(0)) => 1,
Expand All @@ -414,11 +417,11 @@ end
ı = Fib("1")
τ = Fib("τ")
s = sector(; A=ı, B=ı)
@test gradedisequal(s s, gradedrange([s => 1]))
@test gradedisequal((@inferred s s), gradedrange([s => 1]))

s = sector(; A=τ, B=τ)
@test gradedisequal(
s s,
(@inferred s s),
gradedrange([
sector(; A=ı, B=ı) => 1,
sector(; A=τ, B=ı) => 1,
Expand All @@ -436,7 +439,7 @@ end
sector(; A=ı, B=ψ) => 1,
sector(; A=τ, B=ψ) => 1,
])
@test gradedisequal(s s, g)
@test gradedisequal((@inferred s s), g)
end

@testset "Fusion of mixed Abelian and NonAbelian products" begin
Expand All @@ -450,16 +453,16 @@ end
q21 = (N=U1(2),) × (J=SU2(1),)
q22 = (N=U1(2),) × (J=SU2(2),)

@test gradedisequal(q1h q1h, gradedrange([q20 => 1, q21 => 1]))
@test gradedisequal(q10 q1h, gradedrange([q2h => 1]))
@test gradedisequal(q0h q1h, gradedrange([q10 => 1, q11 => 1]))
@test gradedisequal(q11 q11, gradedrange([q20 => 1, q21 => 1, q22 => 1]))
@test gradedisequal((@inferred q1h q1h), gradedrange([q20 => 1, q21 => 1]))
@test gradedisequal((@inferred q10 q1h), gradedrange([q2h => 1]))
@test gradedisequal((@inferred q0h q1h), gradedrange([q10 => 1, q11 => 1]))
@test gradedisequal((@inferred q11 q11), gradedrange([q20 => 1, q21 => 1, q22 => 1]))
end

@testset "Fusion of fully mixed products" begin
s = sector(; A=U1(1), B=SU2(1//2), C=Ising("σ"))
@test gradedisequal(
s s,
(@inferred s s),
gradedrange([
sector(; A=U1(2), B=SU2(0), C=Ising("1")) => 1,
sector(; A=U1(2), B=SU2(1), C=Ising("1")) => 1,
Expand All @@ -472,7 +475,7 @@ end
τ = Fib("τ")
s = sector(; A=SU2(1//2), B=U1(1), C=τ)
@test gradedisequal(
s s,
(@inferred s s),
gradedrange([
sector(; A=SU2(0), B=U1(2), C=ı) => 1,
sector(; A=SU2(1), B=U1(2), C=ı) => 1,
Expand All @@ -483,7 +486,7 @@ end

s = sector(; A=τ, B=U1(1), C=ı)
@test gradedisequal(
s s,
(@inferred s s),
gradedrange([sector(; B=U1(2), A=ı, C=ı) => 1, sector(; B=U1(2), A=τ, C=ı) => 1]),
)
end
Expand All @@ -494,8 +497,9 @@ end
g2 = gradedrange([s2 => 1])
s3 = sector(; A=U1(1), B=SU2(0), C=Ising("σ"))
s4 = sector(; A=U1(1), B=SU2(1), C=Ising("σ"))
# type not inferred on julia 1.6 only
@test gradedisequal(fusion_product(g1, g2), gradedrange([s3 => 2, s4 => 2]))
@test gradedisequal(
(@inferred_latest fusion_product(g1, g2)), gradedrange([s3 => 2, s4 => 2])
)

sA = sector(; A=U1(1))
sB = sector(; B=SU2(1//2))
Expand All @@ -512,16 +516,16 @@ end
@test (@inferred s × s) == s
@test (@inferred s s) == s
@test (@inferred quantum_dimension(s)) == 1
@test trivial(s) == s # need julia 1.10 for type stability
@test (@inferred_latest trivial(s)) == s
@test typeof(s) == typeof(sector(()))
@test typeof(s) == typeof(sector((;))) # empty NamedTuple is cast to Tuple{}

@test s × U1(1) == sector(U1(1))
@test s × sector(U1(1)) == sector(U1(1))
@test s × sector(; A=U1(1)) == sector(; A=U1(1))
@test U1(1) × s == sector(U1(1))
@test sector(U1(1)) × s == sector(U1(1))
@test sector(; A=U1(1)) × s == sector(; A=U1(1))
@test (@inferred s × U1(1)) == sector(U1(1))
@test (@inferred s × sector(U1(1))) == sector(U1(1))
@test (@inferred s × sector(; A=U1(1))) == sector(; A=U1(1))
@test (@inferred U1(1) × s) == sector(U1(1))
@test (@inferred sector(U1(1)) × s) == sector(U1(1))
@test (@inferred sector(; A=U1(1)) × s) == sector(; A=U1(1))

# Empty acts as trivial
@test (@inferred U1(1) s) == U1(1)
Expand Down

0 comments on commit bf31a2e

Please sign in to comment.