diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 0bd35707a7..9197a99fe8 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -68,6 +68,7 @@ end # == is just a range comparison that ignores labels. Need dedicated function to check equality. struct NoLabel end blocklabels(r::AbstractUnitRange) = Fill(NoLabel(), blocklength(r)) +blocklabels(la::LabelledUnitRange) = label(la) function LabelledNumbers.labelled_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange) return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2)) diff --git a/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl index 256cd3f77d..0dd13ddb5a 100644 --- a/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl @@ -12,6 +12,7 @@ nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange dual(a::LabelledUnitRangeDual) = nondual(a) flip(a::LabelledUnitRangeDual) = dual(flip(nondual(a))) isdual(::LabelledUnitRangeDual) = true +blocklabels(la::LabelledUnitRangeDual) = label(la) LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a))) LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a)) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 8ebf75a520..17c03c8edd 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -73,6 +73,21 @@ end @test label(lad) == U1(-1) @test unlabel(lad) == 1:2 @test lad == 1:2 + @test !labelled_isequal(la, lad) + @test !space_isequal(la, lad) + @test isdual(lad) + @test nondual(lad) === la + @test dual(lad) === la + + # check default behavior for objects without dual + la = labelled(1:2, 'x') + lad = dual(la) + @test lad isa LabelledUnitRangeDual + @test label(lad) == 'x' + @test unlabel(lad) == 1:2 + @test lad == 1:2 + @test labelled_isequal(la, lad) + @test !space_isequal(la, lad) @test isdual(lad) @test nondual(lad) === la @test dual(lad) === la