Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define map_blocklabels #5

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,4 @@ nondual_type(x) = nondual_type(typeof(x))
nondual_type(T::Type) = T

dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
label_dual(x) = label_dual(LabelledStyle(x), x)
label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))

flip(a::AbstractUnitRange) = dual(label_dual(a))
flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g))))
flip(a::AbstractUnitRange) = dual(map_blocklabels(dual, a))
6 changes: 6 additions & 0 deletions src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,9 @@ function blockedunitrange_getindices(
# if `a isa `GradedUnitRange`, for example.
return mortar(blks, labelled_length.(blks))
end

map_blocklabels(::Any, a::AbstractUnitRange) = a
function map_blocklabels(f, g::AbstractGradedUnitRange)
# use labelled_blocks to preserve GradedUnitRange
return labelled_blocks(unlabel_blocks(g), f.(blocklabels(g)))
end
16 changes: 10 additions & 6 deletions src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
## TODO: Define this to instantiate a dual unit range.
## materialize_dual(a::GradedUnitRangeDual) = materialize_dual(nondual(a))

Base.first(a::GradedUnitRangeDual) = label_dual(first(nondual(a)))
Base.last(a::GradedUnitRangeDual) = label_dual(last(nondual(a)))
Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a)))
Base.first(a::GradedUnitRangeDual) = dual(first(nondual(a)))
Base.last(a::GradedUnitRangeDual) = dual(last(nondual(a)))
Base.step(a::GradedUnitRangeDual) = dual(step(nondual(a)))

Check warning on line 32 in src/gradedunitrangedual.jl

View check run for this annotation

Codecov / codecov/patch

src/gradedunitrangedual.jl#L32

Added line #L32 was not covered by tests

Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index]

Expand All @@ -40,7 +40,7 @@
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))

Check warning on line 43 in src/gradedunitrangedual.jl

View check run for this annotation

Codecov / codecov/patch

src/gradedunitrangedual.jl#L43

Added line #L43 was not covered by tests
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
Expand Down Expand Up @@ -123,8 +123,8 @@
end

BlockArrays.blockaxes(a::GradedUnitRangeDual) = blockaxes(nondual(a))
BlockArrays.blockfirsts(a::GradedUnitRangeDual) = label_dual.(blockfirsts(nondual(a)))
BlockArrays.blocklasts(a::GradedUnitRangeDual) = label_dual.(blocklasts(nondual(a)))
BlockArrays.blockfirsts(a::GradedUnitRangeDual) = dual.(blockfirsts(nondual(a)))
BlockArrays.blocklasts(a::GradedUnitRangeDual) = dual.(blocklasts(nondual(a)))
function BlockArrays.findblock(a::GradedUnitRangeDual, index::Integer)
return findblock(nondual(a), index)
end
Expand All @@ -138,3 +138,7 @@
function unlabel_blocks(a::GradedUnitRangeDual)
return unlabel_blocks(nondual(a))
end

function map_blocklabels(f, g::GradedUnitRangeDual)
return dual(map_blocklabels(f, dual(g)))

Check warning on line 143 in src/gradedunitrangedual.jl

View check run for this annotation

Codecov / codecov/patch

src/gradedunitrangedual.jl#L142-L143

Added lines #L142 - L143 were not covered by tests
end
4 changes: 3 additions & 1 deletion src/labelledunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ end
dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a)
nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange
dual(a::LabelledUnitRangeDual) = nondual(a)
label_dual(::IsLabelled, a::LabelledUnitRangeDual) = dual(label_dual(nondual(a)))
isdual(::LabelledUnitRangeDual) = true
blocklabels(la::LabelledUnitRangeDual) = [label(la)]

map_blocklabels(f, la::LabelledUnitRange) = labelled(unlabel(la), f(label(la)))
map_blocklabels(f, lad::LabelledUnitRangeDual) = dual(map_blocklabels(f, nondual(lad)))

function nondual_type(
::Type{<:LabelledUnitRangeDual{<:Any,NondualUnitRange}}
) where {NondualUnitRange}
Expand Down
7 changes: 7 additions & 0 deletions test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,24 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n

a = 1:3
ad = dual(a)
af = flip(a)
@test !isdual(a)
@test !isdual(ad)
@test !isdual(af)
@test ad isa UnitRange
@test af isa UnitRange
@test space_isequal(ad, a)
@test space_isequal(af, a)

a = blockedrange([2, 3])
ad = dual(a)
af = flip(a)
@test !isdual(a)
@test !isdual(ad)
@test ad isa BlockedOneTo
@test af isa BlockedOneTo
@test blockisequal(ad, a)
@test blockisequal(af, a)
end

@testset "LabelledUnitRangeDual" begin
Expand Down
Loading