Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 3, 2024
1 parent 64e5597 commit e948603
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
5 changes: 5 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/unitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ function unitrangedual_getindices_blocks(a, indices)
return mortar([dual(b) for b in blocks(a_indices)])
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1})
return a[indices]
end

function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}})
return unitrangedual_getindices_blocks(a, indices)
end
Expand Down
29 changes: 29 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,32 @@ Base.:-(x::LabelledInteger) = labelled_minus(x)
# TODO: This is only needed for older Julia versions, like Julia 1.6.
# Delete once we drop support for older Julia versions.
Base.hash(x::LabelledInteger, h::UInt64) = labelled_hash(x, h)

using Random: AbstractRNG, default_rng
default_eltype() = Float64
for f in [:rand, :randn]
@eval begin
function Base.$f(
rng::AbstractRNG,
elt::Type{<:Number},
dims::Tuple{LabelledInteger,Vararg{LabelledInteger}},
)
return a = $f(rng, elt, unlabel.(dims))
end
function Base.$f(
rng::AbstractRNG,
elt::Type{<:Number},
dim1::LabelledInteger,
dims::Vararg{LabelledInteger},
)
return $f(rng, elt, (dim1, dims...))
end
Base.$f(elt::Type{<:Number}, dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) =
$f(default_rng(), elt, dims)
Base.$f(elt::Type{<:Number}, dim1::LabelledInteger, dims::Vararg{LabelledInteger}) =
$f(elt, (dim1, dims...))
Base.$f(dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) =
$f(default_eltype(), dims)
Base.$f(dim1::LabelledInteger, dims::Vararg{LabelledInteger}) = $f((dim1, dims...))
end
end
24 changes: 24 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
@eval module $(gensym())
using LinearAlgebra: norm
using NDTensors.LabelledNumbers: LabelledInteger, islabelled, label, labelled, unlabel
using Test: @test, @testset
@testset "LabelledNumbers" begin
Expand Down Expand Up @@ -48,6 +49,29 @@ using Test: @test, @testset
@test one(typeof(x)) == true
@test !islabelled(one(typeof(x)))
end
@testset "randn" begin
d = labelled(2, "x")

a = randn(Float32, d, d)
@test eltype(a) === Float32
@test size(a) == (2, 2)
@test norm(a) > 0

a = rand(Float32, d, d)
@test eltype(a) === Float32
@test size(a) == (2, 2)
@test norm(a) > 0

a = randn(d, d)
@test eltype(a) === Float64
@test size(a) == (2, 2)
@test norm(a) > 0

a = rand(d, d)
@test eltype(a) === Float64
@test size(a) == (2, 2)
@test norm(a) > 0
end
@testset "Labelled array ($a)" for a in (collect(2:5), 2:5)
x = labelled(a, "x")
@test eltype(x) == LabelledInteger{Int,String}
Expand Down

0 comments on commit e948603

Please sign in to comment.