diff --git a/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl index 1953230045..8782ddfd61 100644 --- a/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl @@ -81,7 +81,7 @@ function tensor_product( return fuse(isdual ? dual(l1) : l1, isdual ? dual(l2) : l2) end, ) - return gradedrange(a, nondual_sectors_a, isdual) + return gradedrange(nondual_sectors_a, a, isdual) end function Base.show(io::IO, mimetype::MIME"text/plain", a::AbstractGradedUnitRange) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index a8bbefc182..256ddf74b9 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -1,4 +1,5 @@ -using BlockArrays: BlockArrays, Block, BlockRange, BlockedUnitRange, blockedrange +using BlockArrays: + BlockArrays, Block, BlockRange, BlockedUnitRange, blockedrange, blocklength struct GradedUnitRange{T,S} <: AbstractGradedUnitRange{T,S} blockedrange::BlockedUnitRange{T} @@ -12,6 +13,9 @@ isdual(s::GradedUnitRange) = s.isdual dual(s::GradedUnitRange) = GradedUnitRange(blockedrange(s), nondual_sectors(s), !isdual(s)) function gradedrange(nondual_sectors::Vector, blocklengths::Vector{Int}, isdual=false) + if length(nondual_sectors) != length(blocklengths) + throw(DomainError("Sector and block lengths do not match")) + end return GradedUnitRange(blockedrange(blocklengths), nondual_sectors, isdual) end @@ -19,7 +23,10 @@ function gradedrange(sectors_lengths::Vector{<:Pair{<:Any,Int}}, isdual=false) return gradedrange(first.(sectors_lengths), last.(sectors_lengths), isdual) end -function gradedrange(a::BlockedUnitRange, nondual_sectors::Vector, isdual=false) +function gradedrange(nondual_sectors::Vector, a::BlockedUnitRange, isdual=false) + if length(nondual_sectors) != blocklength(a) + throw(DomainError("Number of sectors and number of blocks do not match")) + end return GradedUnitRange(a, nondual_sectors, isdual) end diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 9be177b2b6..0c612c684d 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -1,5 +1,5 @@ @eval module $(gensym()) -using BlockArrays: Block, BlockVector, blocklength, blocklengths, findblock +using BlockArrays: Block, BlockVector, blockedrange, blocklength, blocklengths, findblock using NDTensors.GradedAxes: GradedAxes, blockmergesortperm, @@ -10,7 +10,7 @@ using NDTensors.GradedAxes: sector, sectors, tensor_product -using Test: @test, @testset +using Test: @test, @testset, @test_throws struct U1 dim::Int @@ -23,6 +23,7 @@ GradedAxes.dual(l::U1) = U1(-l.dim) a = gradedrange([U1(0), U1(1)], [2, 3]) @test a isa GradedAxes.GradedUnitRange @test a == gradedrange([U1(0) => 2, U1(1) => 3]) + @test a == gradedrange([U1(0), U1(1)], blockedrange([2, 3])) @test length(a) == 5 @test a == 1:5 @test a[Block(1)] == 1:2 @@ -44,6 +45,10 @@ GradedAxes.dual(l::U1) = U1(-l.dim) @test sector(a, 4) == U1(1) @test sector(a, 5) == U1(1) + # test error for invalid input + @test_throws DomainError gradedrange([U1(0), U1(1)], [2, 3, 4]) + @test_throws DomainError gradedrange([U1(0), U1(1)], blockedrange([2, 3, 4])) + # Naive tensor product, no sorting and merging a = gradedrange([U1(0), U1(1)], [2, 3]) a2 = tensor_product(a, a)