Skip to content

Commit

Permalink
use GradedAxes interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Oct 31, 2024
1 parent 032bde4 commit e716903
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,28 @@ Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a)))

Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index]

function Base.getindex(a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer})
function gradedunitrange_getindices(
a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer}
)
return dual(getindex(nondual(a), indices))
end

using BlockArrays: Block, BlockIndexRange, BlockRange

function Base.getindex(a::GradedUnitRangeDual, indices::Integer)
function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
return label_dual(getindex(nondual(a), indices))
end

function Base.getindex(a::GradedUnitRangeDual, indices::Block{1})
function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
return label_dual(getindex(nondual(a), indices))
end

function Base.getindex(a::GradedUnitRangeDual, indices::BlockRange)
function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange)
return label_dual(getindex(nondual(a), indices))
end

# fix ambiguity
function Base.getindex(
function gradedunitrange_getindices(
a::GradedUnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}}
)
return dual(getindex(nondual(a), indices))
Expand All @@ -52,15 +54,13 @@ function unitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices)
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
return a[indices]
end

function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:Block{1}})
function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}})
return unitrangedual_getindices_blocks(a, indices)
end

function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}})
function gradedunitrange_getindices(
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
)
return unitrangedual_getindices_blocks(a, indices)
end

Expand All @@ -79,16 +79,10 @@ function BlockArrays.BlockSlice(b::Block, r::GradedUnitRangeDual)
end

using NDTensors.LabelledNumbers: LabelledNumbers, LabelledUnitRange, label
# The Base version of `length(::AbstractUnitRange)` drops the label.
function Base.length(a::GradedUnitRangeDual{<:Any,<:LabelledUnitRange})
return dual(length(nondual(a)))
end
function Base.iterate(a::GradedUnitRangeDual, i)
i == last(a) && return nothing
return dual.(iterate(nondual(a), i))
end
# TODO: Is this a good definition?
Base.unitrange(a::GradedUnitRangeDual) = a

using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel
using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock
Expand Down

0 comments on commit e716903

Please sign in to comment.