diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 3ce65a4d15..8f1a86f1b8 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -38,6 +38,11 @@ function unitrangedual_getindices_blocks(a, indices) return mortar([dual(b) for b in blocks(a_indices)]) end +# TODO: Move this to a `BlockArraysExtensions` library. +function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1}) + return a[indices] +end + function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}}) return unitrangedual_getindices_blocks(a, indices) end diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl index f5e2d58f3d..323d252b0c 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl @@ -86,3 +86,32 @@ Base.:-(x::LabelledInteger) = labelled_minus(x) # TODO: This is only needed for older Julia versions, like Julia 1.6. # Delete once we drop support for older Julia versions. 
Base.hash(x::LabelledInteger, h::UInt64) = labelled_hash(x, h) + +using Random: AbstractRNG, default_rng +default_eltype() = Float64 +for f in [:rand, :randn] + @eval begin + function Base.$f( + rng::AbstractRNG, + elt::Type{<:Number}, + dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}, + ) + return $f(rng, elt, unlabel.(dims)) + end + function Base.$f( + rng::AbstractRNG, + elt::Type{<:Number}, + dim1::LabelledInteger, + dims::Vararg{LabelledInteger}, + ) + return $f(rng, elt, (dim1, dims...)) + end + Base.$f(elt::Type{<:Number}, dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) = + $f(default_rng(), elt, dims) + Base.$f(elt::Type{<:Number}, dim1::LabelledInteger, dims::Vararg{LabelledInteger}) = + $f(elt, (dim1, dims...)) + Base.$f(dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) = + $f(default_eltype(), dims) + Base.$f(dim1::LabelledInteger, dims::Vararg{LabelledInteger}) = $f((dim1, dims...)) + end +end diff --git a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl index cf3f87e86d..6fc1ac4231 100644 --- a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl +++ b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl @@ -1,4 +1,5 @@ @eval module $(gensym()) +using LinearAlgebra: norm using NDTensors.LabelledNumbers: LabelledInteger, islabelled, label, labelled, unlabel using Test: @test, @testset @testset "LabelledNumbers" begin @@ -48,6 +49,29 @@ using Test: @test, @testset @test one(typeof(x)) == true @test !islabelled(one(typeof(x))) end + @testset "randn" begin + d = labelled(2, "x") + + a = randn(Float32, d, d) + @test eltype(a) === Float32 + @test size(a) == (2, 2) + @test norm(a) > 0 + + a = rand(Float32, d, d) + @test eltype(a) === Float32 + @test size(a) == (2, 2) + @test norm(a) > 0 + + a = randn(d, d) + @test eltype(a) === Float64 + @test size(a) == (2, 2) + @test norm(a) > 0 + + a = rand(d, d) + @test eltype(a) === Float64 + @test size(a) == (2, 2) + @test norm(a) > 0 + end @testset 
"Labelled array ($a)" for a in (collect(2:5), 2:5) x = labelled(a, "x") @test eltype(x) == LabelledInteger{Int,String}