Skip to content

Commit

Permalink
passing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Oct 28, 2024
1 parent cc6d7ea commit 3f6bd2e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 58 deletions.
27 changes: 8 additions & 19 deletions NDTensors/src/lib/GradedAxes/src/fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,28 @@ function tensor_product(
return foldl(tensor_product, (a1, a2, a3, a_rest...))
end

function tensor_product(::AbstractUnitRange, ::AbstractUnitRange)
return error("Not implemented yet.")
flip_dual(r::AbstractUnitRange) = r
flip_dual(r::GradedUnitRangeDual) = flip(r)
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
return tensor_product(flip_dual(a1), flip_dual(a2))
end

function tensor_product(a1::Base.OneTo, a2::Base.OneTo)
return Base.OneTo(length(a1) * length(a2))
end

function tensor_product(::OneToOne, a2::AbstractBlockedUnitRange)
function tensor_product(::OneToOne, a2::AbstractUnitRange)
return a2
end

function tensor_product(a1::AbstractBlockedUnitRange, ::OneToOne)
function tensor_product(a1::AbstractUnitRange, ::OneToOne)
return a1
end

function tensor_product(::OneToOne, ::OneToOne)
return OneToOne()
end

# Handle dual. Always return a non-dual GradedUnitRange.
function tensor_product(a1::AbstractBlockedUnitRange, a2::GradedUnitRangeDual)
return tensor_product(a1, flip(a2))
end

function tensor_product(a1::GradedUnitRangeDual, a2::AbstractBlockedUnitRange)
return tensor_product(flip(a1), a2)
end

function tensor_product(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual)
return tensor_product(flip(a1), flip(a2))
end

function fuse_labels(x, y)
return error(
"`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`."
Expand Down Expand Up @@ -98,7 +87,8 @@ end
# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
# Get the permutation for sorting, then group by common elements.
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
function blockmergesortperm(a::AbstractBlockedUnitRange)
blockmergesort(g::AbstractUnitRange) = g
function blockmergesortperm(a::AbstractUnitRange)
return Block.(groupsortperm(blocklabels(a)))
end

Expand All @@ -120,7 +110,6 @@ function blockmergesort(g::AbstractGradedUnitRange)
end

blockmergesort(g::GradedUnitRangeDual) = dual(blockmergesort(flip(g)))
blockmergesort(g::OneToOne) = g

# fusion_product produces a sorted, non-dual GradedUnitRange
function fusion_product(g1, g2)
Expand Down
26 changes: 1 addition & 25 deletions NDTensors/src/lib/GradedAxes/src/onetoone.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,4 @@ struct OneToOne{T} <: AbstractUnitRange{T} end
OneToOne() = OneToOne{Bool}()
Base.first(a::OneToOne) = one(eltype(a))
Base.last(a::OneToOne) = one(eltype(a))

# == is just a range comparison that ignores labels. Need dedicated function to check equality.
gradedisequal(::AbstractBlockedUnitRange, ::AbstractUnitRange) = false
gradedisequal(::AbstractUnitRange, ::AbstractBlockedUnitRange) = false
gradedisequal(::AbstractBlockedUnitRange, ::OneToOne) = false
gradedisequal(::OneToOne, ::AbstractBlockedUnitRange) = false
function gradedisequal(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
return blockisequal(a1, a2)
end
function gradedisequal(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
end
gradedisequal(::GradedUnitRangeDual, ::GradedUnitRange) = false
gradedisequal(::GradedUnitRange, ::GradedUnitRangeDual) = false
function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual)
return gradedisequal(nondual(a1), nondual(a2))
end

gradedisequal(::OneToOne, ::OneToOne) = true

function gradedisequal(::OneToOne, g::AbstractUnitRange)
return !islabelled(eltype(g)) && (first(g) == last(g) == 1)
end
gradedisequal(g::AbstractUnitRange, a0::OneToOne) = gradedisequal(a0, g)
gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) = a1 == a2
BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool}
22 changes: 11 additions & 11 deletions NDTensors/src/lib/GradedAxes/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ using BlockArrays:
blocklength,
blocklengths,
blocks
using NDTensors.GradedAxes:
GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedisequal, gradedrange
using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange
using NDTensors.LabelledNumbers:
LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel
using Test: @test, @test_broken, @testset
Expand All @@ -20,13 +19,14 @@ using Test: @test, @test_broken, @testset
@test a0 isa OneToOne{Bool}
@test eltype(a0) == Bool
@test length(a0) == 1
@test gradedisequal(a0, a0)
@test labelled_isequal(a0, a0)

@test gradedisequal(a0, 1:1)
@test gradedisequal(1:1, a0)
@test !gradedisequal(a0, 1:2)
@test !gradedisequal(1:2, a0)
@test labelled_isequal(a0, 1:1)
@test labelled_isequal(1:1, a0)
@test !labelled_isequal(a0, 1:2)
@test !labelled_isequal(1:2, a0)
end

@testset "GradedAxes basics" begin
a0 = OneToOne()
for a in (
Expand All @@ -35,10 +35,10 @@ end
gradedrange(["x" => 2, "y" => 3]),
)
@test a isa GradedOneTo
@test gradedisequal(a, a)
@test !gradedisequal(a0, a)
@test !gradedisequal(a, a0)
@test !gradedisequal(a, 1:5)
@test labelled_isequal(a, a)
@test !labelled_isequal(a0, a)
@test !labelled_isequal(a, a0)
@test !labelled_isequal(a, 1:5)
for x in iterate(a)
@test x == 1
@test label(x) == "x"
Expand Down
6 changes: 4 additions & 2 deletions NDTensors/src/lib/GradedAxes/test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using BlockArrays:
blockaxes,
blockedrange,
blockfirsts,
blockisequal,
blocklasts,
blocklength,
blocklengths,
Expand All @@ -23,7 +24,7 @@ using NDTensors.GradedAxes:
gradedrange,
isdual,
nondual
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal
using Test: @test, @testset
struct U1
n::Int
Expand All @@ -36,6 +37,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
@test !isdual(a0)
@test dual(a0) isa OneToOne
@test space_isequal(a0, a0)
@test labelled_isequal(a0, a0)
@test space_isequal(a0, dual(a0))

a = 1:3
Expand All @@ -50,7 +52,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
@test !isdual(a)
@test !isdual(ad)
@test ad isa BlockedOneTo
@test space_isequal(ad, a)
@test blockisequal(ad, a)
end

@testset "GradedUnitRangeDual" begin
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ using NDTensors.GradedAxes:
fusion_product,
flip,
gradedrange,
labelled_isequal,
space_isequal,
isdual,
tensor_product

using NDTensors.LabelledNumbers: labelled_isequal

struct U1
n::Int
end
Expand All @@ -27,6 +28,7 @@ GradedAxes.fuse_labels(x::U1, y::U1) = U1(x.n + y.n)
GradedAxes.fuse_labels(x::String, y::String) = x * y

g0 = OneToOne()
@test labelled_isequal(g0, g0)
@test labelled_isequal(tensor_product(g0, g0), g0)

a = gradedrange(["x" => 2, "y" => 3])
Expand Down

0 comments on commit 3f6bd2e

Please sign in to comment.