Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Jun 14, 2024
1 parent a86e67e commit 3b4e3cb
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion NDTensors/src/lib/GradedAxes/src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))

# TBD rename deepdual? yet another name?
label_dual(g::GradedUnitRange) = gradedrange(label_dual.(blocklengths(g)))
label_dual(g::AbstractGradedUnitRange) = gradedrange(label_dual.(blocklengths(g)))
4 changes: 2 additions & 2 deletions NDTensors/src/lib/GradedAxes/src/fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function fuse_blocklengths(x::LabelledInteger, y::LabelledInteger)
end

flatten_maybe_nested(v::Vector{<:Integer}) = v
flatten_maybe_nested(v::Vector{<:GradedUnitRange}) = reduce(vcat, blocklengths.(v))
flatten_maybe_nested(v::Vector{<:AbstractGradedUnitRange}) = reduce(vcat, blocklengths.(v))

using BlockArrays: blockedrange, blocks
function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
Expand Down Expand Up @@ -117,7 +117,7 @@ function blockmergesortperm(a::UnitRangeDual)
return Block.(groupsortperm(blocklabels(nondual(a))))
end

function blockmergesort(g::GradedUnitRange)
function blockmergesort(g::AbstractGradedUnitRange)
glabels = blocklabels(g)
gblocklengths = blocklengths(g)
new_blocklengths = map(
Expand Down
5 changes: 4 additions & 1 deletion NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ end

# == is just a range comparison that ignores labels. Need dedicated function to check equality.
function gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
# TODO remove workaround once BlockArrays.blockisequal is generalized to Integer
blocka1 = BlockArrays.blockedrange(GradedAxes.unlabel.(BlockArrays.blocklengths(a1)))
blocka2 = BlockArrays.blockedrange(GradedAxes.unlabel.(BlockArrays.blocklengths(a2)))
return blockisequal(blocka1, blocka2) && (blocklabels(a1) == blocklabels(a2))
end

# TODO: Use `TypeParameterAccessors`.
Expand Down
6 changes: 3 additions & 3 deletions NDTensors/src/lib/GradedAxes/src/unitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dual(a::AbstractUnitRange) = UnitRangeDual(a)
nondual(a::UnitRangeDual) = a.nondual_unitrange
dual(a::UnitRangeDual) = nondual(a)
nondual(a::AbstractUnitRange) = a
isdual(::GradedUnitRange) = false
isdual(::AbstractGradedUnitRange) = false
isdual(::UnitRangeDual) = true
## TODO: Define this to instantiate a dual unit range.
## materialize_dual(a::UnitRangeDual) = materialize_dual(nondual(a))
Expand Down Expand Up @@ -78,8 +78,8 @@ BlockArrays.findblock(a::UnitRangeDual, index::Integer) = findblock(nondual(a),

blocklabels(a::UnitRangeDual) = dual.(blocklabels(nondual(a)))

gradedisequal(a1::UnitRangeDual, a2::GradedUnitRange) = false
gradedisequal(a1::GradedUnitRange, a2::UnitRangeDual) = false
gradedisequal(::UnitRangeDual, ::AbstractGradedUnitRange) = false
gradedisequal(::AbstractGradedUnitRange, ::UnitRangeDual) = false
function gradedisequal(a1::UnitRangeDual, a2::UnitRangeDual)
return gradedisequal(nondual(a1), nondual(a2))
end
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/lib/GradedAxes/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ using BlockArrays:
blocklengths,
blocks
using NDTensors.BlockSparseArrays: BlockSparseVector
using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, blocklabels, gradedrange
using NDTensors.GradedAxes:
GradedOneTo, GradedUnitRange, blocklabels, gradedisequal, gradedrange
using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, unlabel
using Test: @test, @test_broken, @testset
@testset "GradedAxes basics" begin
Expand Down Expand Up @@ -41,6 +42,7 @@ using Test: @test, @test_broken, @testset
@test label(x) == "y"
end
@test isnothing(iterate(a, labelled(5, "y")))
@test gradedisequal(a, a)
@test length(a) == 5
@test step(a) == 1
@test !islabelled(step(a))
Expand Down
13 changes: 9 additions & 4 deletions NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
@eval module $(gensym())
using NDTensors.GradedAxes: GradedAxes, GradedOneTo
GradedUnitRange, OneToOne, fusion_product, gradedrange, gradedisequal, tensor_product
using NDTensors.GradedAxes:
GradedAxes,
GradedOneTo,
OneToOne,
fusion_product,
gradedrange,
gradedisequal,
tensor_product
using BlockArrays: blocklength, blocklengths
using Test: @test, @testset

Expand All @@ -12,16 +18,15 @@ using Test: @test, @testset

a = gradedrange(["x" => 2, "y" => 3])
b = tensor_product(a, a)
@test b isa GradedUnitRange
@test b isa GradedOneTo
@test length(b) == 25
@test blocklength(b) == 4
@test blocklengths(b) == [4, 6, 6, 9]
@test gradedisequal(b, gradedrange(["xx" => 4, "yx" => 6, "xy" => 6, "yy" => 9]))

c = tensor_product(a, a, a)
@test c isa GradedOneTo
@test length(c) == 125
@test c isa GradedUnitRange
@test blocklength(c) == 8
end

Expand Down

0 comments on commit 3b4e3cb

Please sign in to comment.