From f6911eea7735e9ecbc39da8b21fc0ad64cea8bad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 7 Nov 2024 15:37:13 -0500 Subject: [PATCH] fix getindex and iterate --- .../lib/GradedAxes/src/labelledunitrangedual.jl | 9 +++++++++ NDTensors/src/lib/GradedAxes/test/test_dual.jl | 17 +++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl index 37a257ce94..466d64945b 100644 --- a/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl @@ -25,6 +25,15 @@ end # fix ambiguities Base.getindex(a::LabelledUnitRangeDual, i::Integer) = dual(nondual(a)[i]) +function Base.getindex(a::LabelledUnitRangeDual, indices::AbstractUnitRange{<:Integer}) + return dual(nondual(a)[indices]) +end + +function Base.iterate(a::LabelledUnitRangeDual, i) + i == last(a) && return nothing + next = convert(eltype(a), labelled(i + step(a), label(a))) + return (next, next) +end function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRangeDual) println(io, typeof(a)) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index c4f67c51a9..98b8838542 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -72,6 +72,11 @@ end @test space_isequal(la, la) @test label_type(la) == U1 + @test iterate(la) == (1, 1) + @test iterate(la) == (1, 1) + @test iterate(la, 1) == (2, 2) + @test isnothing(iterate(la, 2)) + lad = dual(la) @test lad isa LabelledUnitRangeDual @test label(lad) == U1(-1) @@ -87,10 +92,22 @@ end @test dual(lad) === la @test label_type(lad) == U1 + @test iterate(lad) == (1, 1) + @test iterate(lad) == (1, 1) + @test iterate(lad, 1) == (2, 2) + @test isnothing(iterate(lad, 2)) + + lad2 = lad[1:1] + @test lad2 isa LabelledUnitRangeDual + @test label(lad2) == U1(-1) + @test unlabel(lad2) == 1:1 + laf = flip(la) @test laf isa LabelledUnitRangeDual @test label(laf) == U1(1) @test unlabel(laf) == 1:2 + @test labelled_isequal(la, laf) + @test !space_isequal(la, laf) ladf = flip(dual(la)) @test ladf isa LabelledUnitRange