Skip to content

Commit

Permalink
fix GradedUnitRangeDual tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Oct 31, 2024
1 parent db71962 commit c57bc02
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ isdual(::AbstractUnitRange) = false

using NDTensors.LabelledNumbers:
LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel

dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
label_dual(x) = label_dual(LabelledStyle(x), x)
label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))
Expand Down
2 changes: 0 additions & 2 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ end
Base.unitrange(a::GradedUnitRangeDual) = a

using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel
dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))

using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock
BlockArrays.blockaxes(a::GradedUnitRangeDual) = blockaxes(nondual(a))
BlockArrays.blockfirsts(a::GradedUnitRangeDual) = label_dual.(blockfirsts(nondual(a)))
Expand Down
4 changes: 4 additions & 0 deletions NDTensors/src/lib/GradedAxes/test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using BlockArrays:
blocks,
findblock
using NDTensors.GradedAxes:
AbstractGradedUnitRange,
GradedAxes,
GradedUnitRangeDual,
OneToOne,
Expand Down Expand Up @@ -60,6 +61,7 @@ end
[gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]]
ad = dual(a)
@test ad isa GradedUnitRangeDual
@test ad isa AbstractGradedUnitRange
@test eltype(ad) == LabelledInteger{Int,U1}
@test blocklengths(ad) isa Vector
@test eltype(blocklengths(ad)) == eltype(blocklengths(a))
Expand All @@ -78,6 +80,8 @@ end
@test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))]
@test blocklength(ad) == 2
@test blocklengths(ad) == [2, 3]
@test blocklabels(ad) == [U1(0), U1(-1)]
@test label.(blocklengths(ad)) == [U1(0), U1(-1)]
@test findblock(ad, 4) == Block(2)
@test only(blockaxes(ad)) == Block(1):Block(2)
@test blocks(ad) == [labelled(1:2, U1(0)), labelled(3:5, U1(-1))]
Expand Down

0 comments on commit c57bc02

Please sign in to comment.