Skip to content

Commit

Permalink
rename sectors product_sectors
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Oct 8, 2024
1 parent 2fe8bbc commit ddb62d5
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 110 deletions.
157 changes: 82 additions & 75 deletions NDTensors/src/lib/SymmetrySectors/src/sector_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,40 @@ using ..GradedAxes: AbstractGradedUnitRange, GradedAxes, dual

# ===================================== Definition =======================================
struct SectorProduct{Sectors} <: AbstractSector
sectors::Sectors
product_sectors::Sectors
global _SectorProduct(l) = new{typeof(l)}(l)
end

SectorProduct(c::SectorProduct) = _SectorProduct(sectors(c))
SectorProduct(c::SectorProduct) = _SectorProduct(product_sectors(c))

sectors(s::SectorProduct) = s.sectors
product_sectors(s::SectorProduct) = s.product_sectors

# ================================= Sectors interface ====================================
SymmetryStyle(T::Type{<:SectorProduct}) = sectors_symmetrystyle(sectors_type(T))
function SymmetryStyle(T::Type{<:SectorProduct})
return product_sectors_symmetrystyle(product_sectors_type(T))
end

function quantum_dimension(::NotAbelianStyle, s::SectorProduct)
return mapreduce(quantum_dimension, *, sectors(s))
return mapreduce(quantum_dimension, *, product_sectors(s))
end

# use map instead of broadcast to support both Tuple and NamedTuple
GradedAxes.dual(s::SectorProduct) = SectorProduct(map(dual, sectors(s)))
GradedAxes.dual(s::SectorProduct) = SectorProduct(map(dual, product_sectors(s)))

function trivial(type::Type{<:SectorProduct})
return SectorProduct(sectors_trivial(sectors_type(type)))
return SectorProduct(product_sectors_trivial(product_sectors_type(type)))
end

# =================================== Base interface =====================================
function Base.:(==)(A::SectorProduct, B::SectorProduct)
return sectors_isequal(sectors(A), sectors(B))
return product_sectors_isequal(product_sectors(A), product_sectors(B))
end

function Base.show(io::IO, s::SectorProduct)
(length(sectors(s)) < 2) && print(io, "sector")
(length(product_sectors(s)) < 2) && print(io, "sector")
print(io, "(")
symbol = ""
for p in pairs(sectors(s))
for p in pairs(product_sectors(s))
print(io, symbol)
sector_show(io, p[1], p[2])
symbol = " × "
Expand All @@ -50,75 +52,78 @@ sector_show(io::IO, ::Int, v) = print(io, v)
sector_show(io::IO, k::Symbol, v) = print(io, "($k=$v,)")

function Base.isless(s1::SectorProduct, s2::SectorProduct)
return sectors_isless(sectors(s1), sectors(s2))
return product_sectors_isless(product_sectors(s1), product_sectors(s2))
end

# ======================================= shared =========================================
# there are 2 implementations for SectorProduct
# - ordered-like with a Tuple
# - dictionary-like with a NamedTuple

function sym_sectors_insert_unspecified(s1, s2)
return sectors_insert_unspecified(s1, s2), sectors_insert_unspecified(s2, s1)
function sym_product_sectors_insert_unspecified(s1, s2)
return product_sectors_insert_unspecified(s1, s2),
product_sectors_insert_unspecified(s2, s1)
end

function sectors_isequal(s1, s2)
return ==(sym_sectors_insert_unspecified(s1, s2)...)
function product_sectors_isequal(s1, s2)
return ==(sym_product_sectors_insert_unspecified(s1, s2)...)
end

# get clean results when mixing implementations
function sectors_isequal(nt::NamedTuple, ::Tuple{})
return sectors_isequal(nt, (;))
function product_sectors_isequal(nt::NamedTuple, ::Tuple{})
return product_sectors_isequal(nt, (;))
end
function sectors_isequal(::Tuple{}, nt::NamedTuple)
return sectors_isequal((;), nt)
function product_sectors_isequal(::Tuple{}, nt::NamedTuple)
return product_sectors_isequal((;), nt)
end
function sectors_isequal(::NamedTuple{()}, t::Tuple)
return sectors_isequal((), t)
function product_sectors_isequal(::NamedTuple{()}, t::Tuple)
return product_sectors_isequal((), t)
end
function sectors_isequal(t::Tuple, ::NamedTuple{()})
return sectors_isequal(t, ())
function product_sectors_isequal(t::Tuple, ::NamedTuple{()})
return product_sectors_isequal(t, ())
end
sectors_isequal(::Tuple{}, ::NamedTuple{()}) = true
sectors_isequal(::NamedTuple{()}, ::Tuple{}) = true
sectors_isequal(::Tuple, ::NamedTuple) = false
sectors_isequal(::NamedTuple, ::Tuple) = false
product_sectors_isequal(::Tuple{}, ::NamedTuple{()}) = true
product_sectors_isequal(::NamedTuple{()}, ::Tuple{}) = true
product_sectors_isequal(::Tuple, ::NamedTuple) = false
product_sectors_isequal(::NamedTuple, ::Tuple) = false

function sectors_isless(nt::NamedTuple, ::Tuple{})
return sectors_isless(nt, (;))
function product_sectors_isless(nt::NamedTuple, ::Tuple{})
return product_sectors_isless(nt, (;))
end
function sectors_isless(::Tuple{}, nt::NamedTuple)
return sectors_isless((;), nt)
function product_sectors_isless(::Tuple{}, nt::NamedTuple)
return product_sectors_isless((;), nt)
end
function sectors_isless(::NamedTuple{()}, t::Tuple)
return sectors_isless((), t)
function product_sectors_isless(::NamedTuple{()}, t::Tuple)
return product_sectors_isless((), t)
end
function sectors_isless(t::Tuple, ::NamedTuple{()})
return sectors_isless(t, ())
function product_sectors_isless(t::Tuple, ::NamedTuple{()})
return product_sectors_isless(t, ())
end
function sectors_isless(s1, s2)
return isless(sym_sectors_insert_unspecified(s1, s2)...)
function product_sectors_isless(s1, s2)
return isless(sym_product_sectors_insert_unspecified(s1, s2)...)
end

sectors_isless(::NamedTuple, ::Tuple) = throw(ArgumentError("Not implemented"))
sectors_isless(::Tuple, ::NamedTuple) = throw(ArgumentError("Not implemented"))
product_sectors_isless(::NamedTuple, ::Tuple) = throw(ArgumentError("Not implemented"))
product_sectors_isless(::Tuple, ::NamedTuple) = throw(ArgumentError("Not implemented"))

sectors_type(::Type{<:SectorProduct{T}}) where {T} = T
product_sectors_type(::Type{<:SectorProduct{T}}) where {T} = T

function sectors_fusion_rule(sects1, sects2)
shared_sect = shared_sectors_fusion_rule(sectors_common(sects1, sects2)...)
diff_sect = SectorProduct(sectors_diff(sects1, sects2))
function product_sectors_fusion_rule(sects1, sects2)
shared_sect = shared_product_sectors_fusion_rule(
product_sectors_common(sects1, sects2)...
)
diff_sect = SectorProduct(product_sectors_diff(sects1, sects2))
return shared_sect × diff_sect
end

# edge case with empty sectors
sectors_fusion_rule(sects::Tuple, ::NamedTuple{()}) = SectorProduct(sects)
sectors_fusion_rule(::NamedTuple{()}, sects::Tuple) = SectorProduct(sects)
sectors_fusion_rule(sects::NamedTuple, ::Tuple{}) = SectorProduct(sects)
sectors_fusion_rule(::Tuple{}, sects::NamedTuple) = SectorProduct(sects)
# edge case with empty product_sectors
product_sectors_fusion_rule(sects::Tuple, ::NamedTuple{()}) = SectorProduct(sects)
product_sectors_fusion_rule(::NamedTuple{()}, sects::Tuple) = SectorProduct(sects)
product_sectors_fusion_rule(sects::NamedTuple, ::Tuple{}) = SectorProduct(sects)
product_sectors_fusion_rule(::Tuple{}, sects::NamedTuple) = SectorProduct(sects)

function recover_style(T::Type, fused)
style = sectors_symmetrystyle(T)
style = product_sectors_symmetrystyle(T)
return recover_sector_product_type(style, T, fused)
end

Expand Down Expand Up @@ -146,11 +151,11 @@ function recover_sector_product_type(T::Type, c::AbstractSector)
end

function recover_sector_product_type(T::Type, c::SectorProduct)
return recover_sector_product_type(T, sectors(c))
return recover_sector_product_type(T, product_sectors(c))
end

function recover_sector_product_type(T::Type{<:SectorProduct}, sects)
return recover_sector_product_type(sectors_type(T), sects)
return recover_sector_product_type(product_sectors_type(T), sects)
end

function recover_sector_product_type(T::Type, sects)
Expand All @@ -160,7 +165,7 @@ end
# ================================= Cartesian Product ====================================
×(c1::AbstractSector, c2::AbstractSector) = ×(SectorProduct(c1), SectorProduct(c2))
function ×(p1::SectorProduct, p2::SectorProduct)
return SectorProduct(sectors_product(sectors(p1), sectors(p2)))
return SectorProduct(product_sectors_product(product_sectors(p1), product_sectors(p2)))
end

×(a, g::AbstractUnitRange) = ×(to_gradedrange(a), g)
Expand Down Expand Up @@ -194,12 +199,14 @@ end

# generic case: fusion returns a GradedAxes, even for fusion with Empty
function fusion_rule(::NotAbelianStyle, s1::SectorProduct, s2::SectorProduct)
return to_gradedrange(sectors_fusion_rule(sectors(s1), sectors(s2)))
return to_gradedrange(
product_sectors_fusion_rule(product_sectors(s1), product_sectors(s2))
)
end

# Abelian case: fusion returns SectorProduct
function fusion_rule(::AbelianStyle, s1::SectorProduct, s2::SectorProduct)
return sectors_fusion_rule(sectors(s1), sectors(s2))
return product_sectors_fusion_rule(product_sectors(s1), product_sectors(s2))
end

# lift ambiguities for TrivialSector
Expand All @@ -212,41 +219,41 @@ fusion_rule(::NotAbelianStyle, ::TrivialSector, c::SectorProduct) = to_gradedran
SectorProduct(t::Tuple) = _SectorProduct(t)
SectorProduct(sects::AbstractSector...) = SectorProduct(sects)

function sectors_symmetrystyle(T::Type{<:Tuple})
function product_sectors_symmetrystyle(T::Type{<:Tuple})
return mapreduce(SymmetryStyle, combine_styles, fieldtypes(T); init=AbelianStyle())
end

sectors_product(::NamedTuple{()}, l1::Tuple) = l1
sectors_product(l2::Tuple, ::NamedTuple{()}) = l2
sectors_product(l1::Tuple, l2::Tuple) = (l1..., l2...)
product_sectors_product(::NamedTuple{()}, l1::Tuple) = l1
product_sectors_product(l2::Tuple, ::NamedTuple{()}) = l2
product_sectors_product(l1::Tuple, l2::Tuple) = (l1..., l2...)

sectors_trivial(type::Type{<:Tuple}) = trivial.(fieldtypes(type))
product_sectors_trivial(type::Type{<:Tuple}) = trivial.(fieldtypes(type))

function sectors_common(t1::Tuple, t2::Tuple)
function product_sectors_common(t1::Tuple, t2::Tuple)
n = min(length(t1), length(t2))
return t1[begin:n], t2[begin:n]
end

function sectors_diff(t1::Tuple, t2::Tuple)
function product_sectors_diff(t1::Tuple, t2::Tuple)
n1 = length(t1)
n2 = length(t2)
return n1 < n2 ? t2[(n1 + 1):end] : t1[(n2 + 1):end]
end

function shared_sectors_fusion_rule(shared1::T, shared2::T) where {T<:Tuple}
function shared_product_sectors_fusion_rule(shared1::T, shared2::T) where {T<:Tuple}
fused = map(fusion_rule, shared1, shared2)
return recover_style(T, fused)
end

function sectors_insert_unspecified(t1::Tuple, t2::Tuple)
function product_sectors_insert_unspecified(t1::Tuple, t2::Tuple)
n1 = length(t1)
return (t1..., trivial.(t2[(n1 + 1):end])...)
end

# =========================== Dictionary-like implementation =============================
function SectorProduct(nt::NamedTuple)
sectors = sort_keys(nt)
return _SectorProduct(sectors)
product_sectors = sort_keys(nt)
return _SectorProduct(product_sectors)
end

SectorProduct(; kws...) = SectorProduct((; kws...))
Expand All @@ -257,38 +264,38 @@ function SectorProduct(pairs::Pair...)
return SectorProduct(NamedTuple{keys}(vals))
end

function sectors_symmetrystyle(NT::Type{<:NamedTuple})
function product_sectors_symmetrystyle(NT::Type{<:NamedTuple})
return mapreduce(SymmetryStyle, combine_styles, fieldtypes(NT); init=AbelianStyle())
end

function sectors_insert_unspecified(nt1::NamedTuple, nt2::NamedTuple)
diff1 = sectors_trivial(typeof(setdiff_keys(nt2, nt1)))
function product_sectors_insert_unspecified(nt1::NamedTuple, nt2::NamedTuple)
diff1 = product_sectors_trivial(typeof(setdiff_keys(nt2, nt1)))
return sort_keys(union_keys(nt1, diff1))
end

sectors_product(l1::NamedTuple, ::Tuple{}) = l1
sectors_product(::Tuple{}, l2::NamedTuple) = l2
function sectors_product(l1::NamedTuple, l2::NamedTuple)
product_sectors_product(l1::NamedTuple, ::Tuple{}) = l1
product_sectors_product(::Tuple{}, l2::NamedTuple) = l2
function product_sectors_product(l1::NamedTuple, l2::NamedTuple)
if length(intersect_keys(l1, l2)) > 0
throw(ArgumentError("Cannot define product of shared keys"))
end
return union_keys(l1, l2)
end

function sectors_trivial(type::Type{<:NamedTuple{Keys}}) where {Keys}
function product_sectors_trivial(type::Type{<:NamedTuple{Keys}}) where {Keys}
return NamedTuple{Keys}(trivial.(fieldtypes(type)))
end

function sectors_common(nt1::NamedTuple, nt2::NamedTuple)
function product_sectors_common(nt1::NamedTuple, nt2::NamedTuple)
# SectorProduct(nt::NamedTuple) sorts keys at init
@assert issorted(keys(nt1))
@assert issorted(keys(nt2))
return intersect_keys(nt1, nt2), intersect_keys(nt2, nt1)
end

sectors_diff(nt1::NamedTuple, nt2::NamedTuple) = symdiff_keys(nt1, nt2)
product_sectors_diff(nt1::NamedTuple, nt2::NamedTuple) = symdiff_keys(nt1, nt2)

function shared_sectors_fusion_rule(shared1::T, shared2::T) where {T<:NamedTuple}
function shared_product_sectors_fusion_rule(shared1::T, shared2::T) where {T<:NamedTuple}
fused = map(fusion_rule, values(shared1), values(shared2))
return recover_style(T, fused)
end
Loading

0 comments on commit ddb62d5

Please sign in to comment.