diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index 2506e67393..320244ea75 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -12,19 +12,21 @@ function tensor_product( return foldl(tensor_product, (a1, a2, a3, a_rest...)) end -function tensor_product(::AbstractUnitRange, ::AbstractUnitRange) - return error("Not implemented yet.") +flip_dual(r::AbstractUnitRange) = r +flip_dual(r::GradedUnitRangeDual) = flip(r) +function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange) + return tensor_product(flip_dual(a1), flip_dual(a2)) end function tensor_product(a1::Base.OneTo, a2::Base.OneTo) return Base.OneTo(length(a1) * length(a2)) end -function tensor_product(::OneToOne, a2::AbstractBlockedUnitRange) +function tensor_product(::OneToOne, a2::AbstractUnitRange) return a2 end -function tensor_product(a1::AbstractBlockedUnitRange, ::OneToOne) +function tensor_product(a1::AbstractUnitRange, ::OneToOne) return a1 end @@ -32,19 +34,6 @@ function tensor_product(::OneToOne, ::OneToOne) return OneToOne() end -# Handle dual. Always return a non-dual GradedUnitRange. -function tensor_product(a1::AbstractBlockedUnitRange, a2::GradedUnitRangeDual) - return tensor_product(a1, flip(a2)) -end - -function tensor_product(a1::GradedUnitRangeDual, a2::AbstractBlockedUnitRange) - return tensor_product(flip(a1), a2) -end - -function tensor_product(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) - return tensor_product(flip(a1), flip(a2)) -end - function fuse_labels(x, y) return error( "`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`." @@ -98,7 +87,8 @@ end # Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. # Get the permutation for sorting, then group by common elements. # groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] -function blockmergesortperm(a::AbstractBlockedUnitRange) +blockmergesort(g::AbstractUnitRange) = g +function blockmergesortperm(a::AbstractUnitRange) return Block.(groupsortperm(blocklabels(a))) end @@ -120,7 +110,6 @@ function blockmergesort(g::AbstractGradedUnitRange) end blockmergesort(g::GradedUnitRangeDual) = dual(blockmergesort(flip(g))) -blockmergesort(g::OneToOne) = g # fusion_product produces a sorted, non-dual GradedUnitRange function fusion_product(g1, g2) diff --git a/NDTensors/src/lib/GradedAxes/src/onetoone.jl b/NDTensors/src/lib/GradedAxes/src/onetoone.jl index 61ee3c2096..426df396b1 100644 --- a/NDTensors/src/lib/GradedAxes/src/onetoone.jl +++ b/NDTensors/src/lib/GradedAxes/src/onetoone.jl @@ -6,28 +6,4 @@ struct OneToOne{T} <: AbstractUnitRange{T} end OneToOne() = OneToOne{Bool}() Base.first(a::OneToOne) = one(eltype(a)) Base.last(a::OneToOne) = one(eltype(a)) - -# == is just a range comparison that ignores labels. Need dedicated function to check equality. -gradedisequal(::AbstractBlockedUnitRange, ::AbstractUnitRange) = false -gradedisequal(::AbstractUnitRange, ::AbstractBlockedUnitRange) = false -gradedisequal(::AbstractBlockedUnitRange, ::OneToOne) = false -gradedisequal(::OneToOne, ::AbstractBlockedUnitRange) = false -function gradedisequal(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) - return blockisequal(a1, a2) -end -function gradedisequal(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) - return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2)) -end -gradedisequal(::GradedUnitRangeDual, ::GradedUnitRange) = false -gradedisequal(::GradedUnitRange, ::GradedUnitRangeDual) = false -function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) - return gradedisequal(nondual(a1), nondual(a2)) -end - -gradedisequal(::OneToOne, ::OneToOne) = true - -function gradedisequal(::OneToOne, g::AbstractUnitRange) - return !islabelled(eltype(g)) && (first(g) == last(g) == 1) -end -gradedisequal(g::AbstractUnitRange, a0::OneToOne) = gradedisequal(a0, g) -gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) = a1 == a2 +BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool} diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 430b84a368..43dc53302d 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -9,8 +9,7 @@ using BlockArrays: blocklength, blocklengths, blocks -using NDTensors.GradedAxes: - GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedisequal, gradedrange +using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel using Test: @test, @test_broken, @testset @@ -20,13 +19,14 @@ using Test: @test, @test_broken, @testset @test a0 isa OneToOne{Bool} @test eltype(a0) == Bool @test length(a0) == 1 - @test gradedisequal(a0, a0) + @test labelled_isequal(a0, a0) - @test gradedisequal(a0, 1:1) - @test gradedisequal(1:1, a0) - @test !gradedisequal(a0, 1:2) - @test !gradedisequal(1:2, a0) + @test labelled_isequal(a0, 1:1) + @test labelled_isequal(1:1, a0) + @test !labelled_isequal(a0, 1:2) + @test !labelled_isequal(1:2, a0) end + @testset "GradedAxes basics" begin a0 = OneToOne() for a in ( @@ -35,10 +35,10 @@ end gradedrange(["x" => 2, "y" => 3]), ) @test a isa GradedOneTo - @test gradedisequal(a, a) - @test !gradedisequal(a0, a) - @test !gradedisequal(a, a0) - @test !gradedisequal(a, 1:5) + @test labelled_isequal(a, a) + @test !labelled_isequal(a0, a) + @test !labelled_isequal(a, a0) + @test !labelled_isequal(a, 1:5) for x in iterate(a) @test x == 1 @test label(x) == "x" diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 80c218adf2..a0ca3bdf49 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -5,6 +5,7 @@ using BlockArrays: blockaxes, blockedrange, blockfirsts, + blockisequal, blocklasts, blocklength, blocklengths, @@ -23,7 +24,7 @@ using NDTensors.GradedAxes: gradedrange, isdual, nondual -using NDTensors.LabelledNumbers: LabelledInteger, label, labelled +using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal using Test: @test, @testset struct U1 n::Int @@ -36,6 +37,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test !isdual(a0) @test dual(a0) isa OneToOne @test space_isequal(a0, a0) + @test labelled_isequal(a0, a0) @test space_isequal(a0, dual(a0)) a = 1:3 @@ -50,7 +52,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test !isdual(a) @test !isdual(ad) @test ad isa BlockedOneTo - @test space_isequal(ad, a) + @test blockisequal(ad, a) end @testset "GradedUnitRangeDual" begin diff --git a/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl b/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl index 02435b5ba7..99e41454ff 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl @@ -11,11 +11,12 @@ using NDTensors.GradedAxes: fusion_product, flip, gradedrange, - labelled_isequal, space_isequal, isdual, tensor_product +using NDTensors.LabelledNumbers: labelled_isequal + struct U1 n::Int end @@ -27,6 +28,7 @@ GradedAxes.fuse_labels(x::U1, y::U1) = U1(x.n + y.n) GradedAxes.fuse_labels(x::String, y::String) = x * y g0 = OneToOne() + @test labelled_isequal(g0, g0) @test labelled_isequal(tensor_product(g0, g0), g0) a = gradedrange(["x" => 2, "y" => 3])