Skip to content

Commit

Permalink
refactor recover_sector_product_type
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Oct 11, 2024
1 parent 7407e83 commit 7a7f2b8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 26 deletions.
34 changes: 12 additions & 22 deletions NDTensors/src/lib/SymmetrySectors/src/sector_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,41 +122,31 @@ product_sectors_fusion_rule(::NamedTuple{()}, sects::Tuple) = SectorProduct(sect
product_sectors_fusion_rule(sects::NamedTuple, ::Tuple{}) = SectorProduct(sects)
product_sectors_fusion_rule(::Tuple{}, sects::NamedTuple) = SectorProduct(sects)

function fix_fused_product_type(T::Type, fused)
return fix_fused_product_type(product_sectors_symmetrystyle(T), T, fused)
function fix_fused_product_type(Sectors::Type, fused)
return fix_fused_product_type(product_sectors_symmetrystyle(Sectors), Sectors, fused)
end

function fix_fused_product_type(::AbelianStyle, T::Type, fused)
return recover_sector_product_type(T, fused)
function fix_fused_product_type(::AbelianStyle, Sectors::Type, fused)
return recover_sector_product_type(Sectors, fused)
end

function fix_fused_product_type(::NotAbelianStyle, T::Type, fused)
function fix_fused_product_type(::NotAbelianStyle, Sectors::Type, fused)
# convert e.g. Tuple{GradedUnitRange{SU2}, GradedUnitRange{SU2}} into GradedUnitRange{SU2×SU2}
g = reduce(×, fused)
# convention: keep unsorted blocklabels as produced by F order loops in ×
return recover_gradedaxis_product_type(T, g)
return recover_gradedaxis_product_type(Sectors, g)
end

function recover_gradedaxis_product_type(T::Type, g0::AbstractGradedUnitRange)
new_labels = recover_sector_product_type.(T, blocklabels(g0))
function recover_gradedaxis_product_type(Sectors::Type, g0::AbstractGradedUnitRange)
old_labels = blocklabels(g0)
old_sects = product_sectors.(SectorProduct.(old_labels))
new_labels = recover_sector_product_type.(Sectors, old_sects)
new_blocklengths = labelled.(unlabel.(blocklengths(g0)), new_labels)
return gradedrange(new_blocklengths)
end

function recover_sector_product_type(T::Type, c::AbstractSector)
return recover_sector_product_type(T, SectorProduct(c))
end

function recover_sector_product_type(T::Type, c::SectorProduct)
return recover_sector_product_type(T, product_sectors(c))
end

function recover_sector_product_type(T::Type{<:SectorProduct}, sects)
return recover_sector_product_type(product_sectors_type(T), sects)
end

function recover_sector_product_type(T::Type, sects)
return SectorProduct(T(sects))
function recover_sector_product_type(Sectors::Type, sects)
return SectorProduct(Sectors(sects))
end

# ================================= Cartesian Product ====================================
Expand Down
4 changes: 0 additions & 4 deletions NDTensors/src/lib/SymmetrySectors/test/test_sector_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ end
@test (@inferred_latest recover_sector_product_type(
typeof(product_sectors(s)), product_sectors(s)
)) == s
@test (@inferred_latest recover_sector_product_type(typeof(s), product_sectors(s))) == s

s = U1(3) × SU2(1//2) × Fib("τ")
@test length(product_sectors(s)) == 3
Expand Down Expand Up @@ -317,9 +316,6 @@ end
@test (@inferred_latest recover_sector_product_type(
typeof(product_sectors(s)), Tuple(product_sectors(s))
)) == s
@test (@inferred_latest recover_sector_product_type(
typeof(s), Tuple(product_sectors(s))
)) == s
@test s == (B=SU2(2),) × (A=U1(1),)

s = s × (C=Ising("ψ"),)
Expand Down

0 comments on commit 7a7f2b8

Please sign in to comment.