diff --git a/NDTensors/src/lib/GradedAxes/src/dual.jl b/NDTensors/src/lib/GradedAxes/src/dual.jl index 28f140d69e..ca985e30a0 100644 --- a/NDTensors/src/lib/GradedAxes/src/dual.jl +++ b/NDTensors/src/lib/GradedAxes/src/dual.jl @@ -5,6 +5,8 @@ isdual(::AbstractUnitRange) = false using NDTensors.LabelledNumbers: LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel + +dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) label_dual(x) = label_dual(LabelledStyle(x), x) label_dual(::NotLabelled, x) = x label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x))) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 1ffee324fc..c6d79495a5 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -104,8 +104,6 @@ end Base.unitrange(a::GradedUnitRangeDual) = a using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel -dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) - using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock BlockArrays.blockaxes(a::GradedUnitRangeDual) = blockaxes(nondual(a)) BlockArrays.blockfirsts(a::GradedUnitRangeDual) = label_dual.(blockfirsts(nondual(a))) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 04203961a7..714dd04b7a 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -12,6 +12,7 @@ using BlockArrays: blocks, findblock using NDTensors.GradedAxes: + AbstractGradedUnitRange, GradedAxes, GradedUnitRangeDual, OneToOne, @@ -60,6 +61,7 @@ end [gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]] ad = dual(a) @test ad isa GradedUnitRangeDual + @test ad isa AbstractGradedUnitRange @test eltype(ad) == LabelledInteger{Int,U1} @test blocklengths(ad) isa Vector @test eltype(blocklengths(ad)) == eltype(blocklengths(a)) @@ -78,6 +80,8 @@ end @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] @test blocklength(ad) == 2 @test blocklengths(ad) == [2, 3] + @test blocklabels(ad) == [U1(0), U1(-1)] + @test label.(blocklengths(ad)) == [U1(0), U1(-1)] @test findblock(ad, 4) == Block(2) @test only(blockaxes(ad)) == Block(1):Block(2) @test blocks(ad) == [labelled(1:2, U1(0)), labelled(3:5, U1(-1))]