From b10ece8bec4fa793e231b8bacbe56159dc13fc96 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 21 Jun 2024 10:19:31 -0400 Subject: [PATCH] Fix setting block to zero with broadcasting notation --- .../wrappedabstractblocksparsearray.jl | 15 +++++++++++++++ .../blocksparsearrayinterface.jl | 5 ----- .../src/lib/BlockSparseArrays/test/test_basics.jl | 10 ++++++++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 8b51619f99..0a81605262 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -185,7 +185,22 @@ function Base.setindex!(a::BlockSparseArrayLike{<:Any,1}, value, I::Block{1}) return a end +function Base.fill!(a::AbstractBlockSparseArray, value) + if iszero(value) + # This drops all of the blocks. + sparse_zero!(blocks(a)) + return a + end + blocksparse_fill!(a, value) + return a +end + function Base.fill!(a::BlockSparseArrayLike, value) + # TODO: Even if `iszero(value)`, this doesn't drop + # blocks from `a`, and additionally allocates + # new blocks filled with zeros, unlike + # `fill!(a::AbstractBlockSparseArray, value)`. + # Consider changing that behavior when possible. blocksparse_fill!(a, value) return a end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 2a4e288d49..29b0ed2f3f 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -96,11 +96,6 @@ function blocksparse_setindex!( end function blocksparse_fill!(a::AbstractArray, value) - if iszero(value) - # This drops all of the blocks. - sparse_zero!(blocks(a)) - return a - end for b in BlockRange(a) # We can't use: # ```julia diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index edc693fe99..08246c363a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -92,6 +92,16 @@ include("TestBlockSparseArraysUtils.jl") @test block_nstored(a) == 1 @test nstored(a) == 2 * 4 + a = BlockSparseArray{elt}([2, 3], [3, 4]) + a[Block(1, 2)] .= 0 + @test eltype(a) == elt + @test iszero(a[Block(1, 1)]) + @test iszero(a[Block(2, 1)]) + @test iszero(a[Block(1, 2)]) + @test iszero(a[Block(2, 2)]) + @test block_nstored(a) == 1 + @test nstored(a) == 2 * 4 + a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) @views for b in [Block(1, 2), Block(2, 1)] a[b] = randn(elt, size(a[b]))