From e716903699456b41e3e0355665c051573c99eacf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 31 Oct 2024 19:41:29 -0400 Subject: [PATCH] use GradedAxes interface --- .../lib/GradedAxes/src/gradedunitrangedual.jl | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index def3dbfc75..bc0824c098 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -17,26 +17,28 @@ Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a))) Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index] -function Base.getindex(a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer}) +function gradedunitrange_getindices( + a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer} +) return dual(getindex(nondual(a), indices)) end using BlockArrays: Block, BlockIndexRange, BlockRange -function Base.getindex(a::GradedUnitRangeDual, indices::Integer) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer) return label_dual(getindex(nondual(a), indices)) end -function Base.getindex(a::GradedUnitRangeDual, indices::Block{1}) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) return label_dual(getindex(nondual(a), indices)) end -function Base.getindex(a::GradedUnitRangeDual, indices::BlockRange) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange) return label_dual(getindex(nondual(a), indices)) end # fix ambiguity -function Base.getindex( +function gradedunitrange_getindices( a::GradedUnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} ) return dual(getindex(nondual(a), indices)) @@ -52,15 +54,13 @@ function unitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices) end # TODO: Move this to a `BlockArraysExtensions` library. -function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) - return a[indices] -end - -function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) return unitrangedual_getindices_blocks(a, indices) end -function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) +function gradedunitrange_getindices( + a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}} +) return unitrangedual_getindices_blocks(a, indices) end @@ -79,16 +79,10 @@ function BlockArrays.BlockSlice(b::Block, r::GradedUnitRangeDual) end using NDTensors.LabelledNumbers: LabelledNumbers, LabelledUnitRange, label -# The Base version of `length(::AbstractUnitRange)` drops the label. -function Base.length(a::GradedUnitRangeDual{<:Any,<:LabelledUnitRange}) - return dual(length(nondual(a))) -end function Base.iterate(a::GradedUnitRangeDual, i) i == last(a) && return nothing return dual.(iterate(nondual(a), i)) end -# TODO: Is this a good definition? -Base.unitrange(a::GradedUnitRangeDual) = a using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock