Skip to content

Commit

Permalink
[GradedAxes] Check input and set convention for constructors (#1349)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe authored Mar 11, 2024
1 parent 9413240 commit 3d163c1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function tensor_product(
return fuse(isdual ? dual(l1) : l1, isdual ? dual(l2) : l2)
end,
)
return gradedrange(a, nondual_sectors_a, isdual)
return gradedrange(nondual_sectors_a, a, isdual)
end

function Base.show(io::IO, mimetype::MIME"text/plain", a::AbstractGradedUnitRange)
Expand Down
11 changes: 9 additions & 2 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BlockArrays: BlockArrays, Block, BlockRange, BlockedUnitRange, blockedrange
using BlockArrays:
BlockArrays, Block, BlockRange, BlockedUnitRange, blockedrange, blocklength

struct GradedUnitRange{T,S} <: AbstractGradedUnitRange{T,S}
blockedrange::BlockedUnitRange{T}
Expand All @@ -12,14 +13,20 @@ isdual(s::GradedUnitRange) = s.isdual
dual(s::GradedUnitRange) = GradedUnitRange(blockedrange(s), nondual_sectors(s), !isdual(s))

function gradedrange(nondual_sectors::Vector, blocklengths::Vector{Int}, isdual=false)
if length(nondual_sectors) != length(blocklengths)
throw(DomainError("Sector and block lengths do not match"))
end
return GradedUnitRange(blockedrange(blocklengths), nondual_sectors, isdual)
end

function gradedrange(sectors_lengths::Vector{<:Pair{<:Any,Int}}, isdual=false)
return gradedrange(first.(sectors_lengths), last.(sectors_lengths), isdual)
end

function gradedrange(a::BlockedUnitRange, nondual_sectors::Vector, isdual=false)
function gradedrange(nondual_sectors::Vector, a::BlockedUnitRange, isdual=false)
if length(nondual_sectors) != blocklength(a)
throw(DomainError("Number of sectors and number of blocks do not match"))
end
return GradedUnitRange(a, nondual_sectors, isdual)
end

Expand Down
9 changes: 7 additions & 2 deletions NDTensors/src/lib/GradedAxes/test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@eval module $(gensym())
using BlockArrays: Block, BlockVector, blocklength, blocklengths, findblock
using BlockArrays: Block, BlockVector, blockedrange, blocklength, blocklengths, findblock
using NDTensors.GradedAxes:
GradedAxes,
blockmergesortperm,
Expand All @@ -10,7 +10,7 @@ using NDTensors.GradedAxes:
sector,
sectors,
tensor_product
using Test: @test, @testset
using Test: @test, @testset, @test_throws

struct U1
dim::Int
Expand All @@ -23,6 +23,7 @@ GradedAxes.dual(l::U1) = U1(-l.dim)
a = gradedrange([U1(0), U1(1)], [2, 3])
@test a isa GradedAxes.GradedUnitRange
@test a == gradedrange([U1(0) => 2, U1(1) => 3])
@test a == gradedrange([U1(0), U1(1)], blockedrange([2, 3]))
@test length(a) == 5
@test a == 1:5
@test a[Block(1)] == 1:2
Expand All @@ -44,6 +45,10 @@ GradedAxes.dual(l::U1) = U1(-l.dim)
@test sector(a, 4) == U1(1)
@test sector(a, 5) == U1(1)

# test error for invalid input
@test_throws DomainError gradedrange([U1(0), U1(1)], [2, 3, 4])
@test_throws DomainError gradedrange([U1(0), U1(1)], blockedrange([2, 3, 4]))

# Naive tensor product, no sorting and merging
a = gradedrange([U1(0), U1(1)], [2, 3])
a2 = tensor_product(a, a)
Expand Down

0 comments on commit 3d163c1

Please sign in to comment.