From 29c5fbe93cde04f6e500b300dc090d07cda3a44c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 30 Jun 2024 15:03:19 -0400 Subject: [PATCH] Fix some broken tests --- .../BlockArraysExtensions.jl | 16 ++++++++++++++++ .../lib/BlockSparseArrays/test/test_basics.jl | 18 +++++++++--------- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 7641060f09..c8f8a87c5e 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -52,6 +52,22 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}}) return BlockIndices(subblocks, subindices) end +# Used when performing slices like: +# @views a[[Block(2), Block(1)]][2:4, 2:4] +function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockVector{<:BlockIndex{1}}}) + subblocks = mortar( + map(blocks(i.block)) do br + return S.blocks[Int(Block(br))][only(br.indices)] + end, + ) + subindices = mortar( + map(blocks(i.block)) do br + S.indices[br] + end, + ) + return BlockIndices(subblocks, subindices) +end + # Similar to the definition of `BlockArrays.BlockSlices`: # ```julia # const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}} diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index a66448e29f..e73d61847c 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -650,18 +650,18 @@ include("TestBlockSparseArraysUtils.jl") @test block_nstored(c) == 2 @test blocksize(c) == (2, 2) @test blocklengths.(axes(c)) == ([2, 3], [2, 3]) - @test_broken size(c[Block(1, 1)]) == (2, 2) - @test_broken c[Block(1, 1)] == a[Block(2, 2)[2:3, 2:3]] - @test_broken size(c[Block(2, 2)]) == (3, 3) - @test_broken c[Block(2, 2)] == a[Block(1, 1)[1:3, 1:3]] - @test_broken size(c[Block(2, 1)]) == (3, 2) - @test_broken iszero(c[Block(2, 1)]) - @test_broken size(c[Block(1, 2)]) == (2, 3) - @test_broken iszero(c[Block(1, 2)]) + @test size(c[Block(1, 1)]) == (2, 2) + @test c[Block(1, 1)] == a[Block(2, 2)[2:3, 2:3]] + @test size(c[Block(2, 2)]) == (3, 3) + @test c[Block(2, 2)] == a[Block(1, 1)[1:3, 1:3]] + @test size(c[Block(2, 1)]) == (3, 2) + @test iszero(c[Block(2, 1)]) + @test size(c[Block(1, 2)]) == (2, 3) + @test iszero(c[Block(1, 2)]) x = randn(elt, 3, 3) c[Block(2, 2)] = x - @test_broken c[Block(2, 2)] == x + @test c[Block(2, 2)] == x @test a[Block(1, 1)[1:3, 1:3]] == x a = BlockSparseArray{elt}([2, 3], [3, 4])