Skip to content

Commit

Permalink
Fix setting block to zero with broadcasting notation
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 21, 2024
1 parent d2b1104 commit b10ece8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,22 @@ function Base.setindex!(a::BlockSparseArrayLike{<:Any,1}, value, I::Block{1})
return a
end

function Base.fill!(a::AbstractBlockSparseArray, value)
if iszero(value)
# This drops all of the blocks.
sparse_zero!(blocks(a))
return a
end
blocksparse_fill!(a, value)
return a
end

function Base.fill!(a::BlockSparseArrayLike, value)
# TODO: Even if `iszero(value)`, this doesn't drop
# blocks from `a`, and additionally allocates
# new blocks filled with zeros, unlike
# `fill!(a::AbstractBlockSparseArray, value)`.
# Consider changing that behavior when possible.
blocksparse_fill!(a, value)
return a
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,6 @@ function blocksparse_setindex!(
end

function blocksparse_fill!(a::AbstractArray, value)
if iszero(value)
# This drops all of the blocks.
sparse_zero!(blocks(a))
return a
end
for b in BlockRange(a)
# We can't use:
# ```julia
Expand Down
10 changes: 10 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ include("TestBlockSparseArraysUtils.jl")
@test block_nstored(a) == 1
@test nstored(a) == 2 * 4

a = BlockSparseArray{elt}([2, 3], [3, 4])
a[Block(1, 2)] .= 0
@test eltype(a) == elt
@test iszero(a[Block(1, 1)])
@test iszero(a[Block(2, 1)])
@test iszero(a[Block(1, 2)])
@test iszero(a[Block(2, 2)])
@test block_nstored(a) == 1
@test nstored(a) == 2 * 4

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
@views for b in [Block(1, 2), Block(2, 1)]
a[b] = randn(elt, size(a[b]))
Expand Down

0 comments on commit b10ece8

Please sign in to comment.