Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Oct 28, 2024
1 parent 8a353bd commit aa1b655
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using NDTensors.GradedAxes:
GradedAxes, GradedOneTo, GradedUnitRangeDual, blocklabels, dual, gradedrange
GradedAxes,
GradedOneTo,
GradedUnitRange,
GradedUnitRangeDual,
blocklabels,
dual,
gradedrange
using NDTensors.LabelledNumbers: label
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: fusedims, splitdims
Expand Down Expand Up @@ -147,8 +154,50 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
end

@testset "GradedOneTo" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedOneTo
@test axes(a[:, :], i) isa GradedOneTo
end

I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test !GradedAxes.isdual(axes(a[I, I], 1))
end

@testset "GradedUnitRange" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRange
@test_broken axes(a[:, :], i) isa GradedUnitRange
end

I = [Block(1)[1:1]]
@test_broken a[I, :] isa AbstractBlockArray
@test_broken a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
end

# Test case when all axes are dual.
@testset "BlockedOneTo" begin
@testset "dual BlockedOneTo" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
Expand All @@ -162,13 +211,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test_broken axes(a[:, :], i) isa GradedUnitRangeDual
end
I = [Block(1)[1:1]]
@test_broken a[I, :]
@test_broken a[:, I]
@test_broken a[I, :] isa AbstractBlockArray
@test_broken a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
end

@testset "GradedUnitRange" begin
@testset "dual GradedUnitRange" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
Expand All @@ -183,13 +232,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
end

I = [Block(1)[1:1]]
@test_broken a[I, :]
@test_broken a[:, I]
@test_broken a[I, :] isa AbstractBlockArray
@test_broken a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
end

@testset "BlockedUnitRange" begin # self dual
@testset "dual BlockedUnitRange" begin # self dual
r = blockedrange([2, 2])
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
Expand All @@ -211,9 +260,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test !GradedAxes.isdual(axes(a[I, I], 1))
end

# Test case when all axes are dual
# from taking the adjoint.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
# Test case when all axes are dual from taking the adjoint.
for r in (
gradedrange([U1(0) => 2, U1(1) => 2]),
gradedrange([U1(0) => 2, U1(1) => 2])[begin:end],
)
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
Expand All @@ -226,9 +277,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
end

I = [Block(1)[1:1]]
@test size(a[I, :]) == (1, 4)
@test size(a[:, I]) == (4, 1)
@test size(a[I, I]) == (1, 1)
@test_broken size(b[I, :]) == (1, 4)
@test_broken size(b[:, I]) == (4, 1)
@test size(b[I, I]) == (1, 1)
end
end
@testset "Matrix multiplication" begin
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using BlockArrays:
BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock
AbstractBlockedUnitRange,
BlockArrays,
Block,
BlockIndexRange,
BlockedVector,
blocklength,
blocksize,
viewblock

# This splits `BlockIndexRange{N}` into
# `NTuple{N,BlockIndexRange{1}}`.
Expand Down Expand Up @@ -191,7 +198,9 @@ function to_blockindexrange(
# work right now.
return blocks(a.blocks)[Int(I)]
end
function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1})
function to_blockindexrange(
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
)
@assert I in only(blockaxes(a.indices))
return I
end
Expand Down

0 comments on commit aa1b655

Please sign in to comment.