Skip to content

Commit

Permalink
[BlockSparseArrays] More general broadcasting and slicing (#1332)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Mar 22, 2024
1 parent 093d339 commit 957f2af
Show file tree
Hide file tree
Showing 25 changed files with 824 additions and 126 deletions.
Original file line number Diff line number Diff line change
@@ -1,26 +1,74 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using BlockArrays: Block, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray
using NDTensors.GradedAxes: gradedrange
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using NDTensors.GradedAxes: GradedUnitRange, gradedrange
using NDTensors.LabelledNumbers: label
using NDTensors.Sectors: U1
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: fusedims, splitdims
using Random: randn!
function blockdiagonal!(f, a::AbstractArray)
for i in 1:minimum(blocksize(a))
b = Block(ntuple(Returns(i), ndims(a)))
a[b] = f(a[b])
end
return a
end
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
d2 = gradedrange([U1(1) => 1, U1(0) => 1])
a = BlockSparseArray{elt}(d1, d2, d1, d2)
for i in 1:minimum(blocksize(a))
b = Block(i, i, i, i)
a[b] = randn!(a[b])
@testset "map" begin
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)

for b in (a + a, 2 * a)
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test nstored(b) == 32
@test block_nstored(b) == 2
# TODO: Have to investigate why this fails
# on Julia v1.6, or drop support for v1.6.
for i in 1:ndims(a)
@test axes(b, i) isa GradedUnitRange
end
@test label(axes(b, 1)[Block(1)]) == U1(0)
@test label(axes(b, 1)[Block(2)]) == U1(1)
@test Array(a) isa Array{elt}
@test Array(a) == a
@test 2 * Array(a) == b
end

b = a[2:3, 2:3, 2:3, 2:3]
@test size(b) == (2, 2, 2, 2)
@test blocksize(b) == (2, 2, 2, 2)
@test nstored(b) == 2
@test block_nstored(b) == 2
for i in 1:ndims(a)
@test axes(b, i) isa GradedUnitRange
end
@test label(axes(b, 1)[Block(1)]) == U1(0)
@test label(axes(b, 1)[Block(2)]) == U1(1)
@test Array(a) isa Array{elt}
@test Array(a) == a
end
# TODO: Add tests for various slicing operations.
@testset "fusedims" begin
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
m = fusedims(a, (1, 2), (3, 4))
@test axes(m, 1) isa GradedUnitRange
@test axes(m, 2) isa GradedUnitRange
@test a[1, 1, 1, 1] == m[1, 1]
@test a[2, 2, 2, 2] == m[4, 4]
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))
end
m = fusedims(a, (1, 2), (3, 4))
@test a[1, 1, 1, 1] == m[2, 2]
@test a[2, 2, 2, 2] == m[3, 3]
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))
end
end
Original file line number Diff line number Diff line change
@@ -1,7 +1,48 @@
using BlockArrays: AbstractBlockArray, AbstractBlockVector, Block, blockedrange
using BlockArrays:
BlockArrays,
AbstractBlockArray,
AbstractBlockVector,
Block,
BlockRange,
BlockedUnitRange,
BlockVector,
block,
blockaxes,
blockedrange,
blockindex,
blocks,
findblock,
findblockindex
using Compat: allequal
using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices

# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices)
return error("Not implemented")
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block})
return blockedrange([length(a[index]) for index in indices])
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# TODO: Merge blocks.
function sub_axis(a::AbstractUnitRange, indices::BlockVector{<:Block})
# `collect` is needed here, otherwise a `PseudoBlockVector` is
# constructed.
return blockedrange([length(a[index]) for index in collect(indices)])
end

# TODO: Use `Tuple` conversion once
# BlockArrays.jl PR is merged.
block_to_cartesianindex(b::Block) = CartesianIndex(b.n)
Expand Down Expand Up @@ -38,3 +79,110 @@ end
function block_reshape(a::AbstractArray, axes::Vararg{AbstractUnitRange})
return block_reshape(a, axes)
end

function cartesianindices(axes::Tuple, b::Block)
return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes)))
end

# Get the range within a block.
function blockindexrange(axis::AbstractUnitRange, r::UnitRange)
bi1 = findblockindex(axis, first(r))
bi2 = findblockindex(axis, last(r))
b = block(bi1)
# Range must fall within a single block.
@assert b == block(bi2)
i1 = blockindex(bi1)
i2 = blockindex(bi2)
return b[i1:i2]
end

function blockindexrange(
axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N}
) where {N}
brs = blockindexrange.(axes, I.indices)
b = Block(block.(brs))
rs = map(br -> only(br.indices), brs)
return b[rs...]
end

function blockindexrange(a::AbstractArray, I::CartesianIndices)
return blockindexrange(axes(a), I)
end

# Get the blocks the range spans across.
function blockrange(axis::AbstractUnitRange, r::UnitRange)
return findblock(axis, first(r)):findblock(axis, last(r))
end

function blockrange(axis::AbstractUnitRange, r::Int)
error("Slicing with integer values isn't supported.")
return findblock(axis, r)
end

function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
for b in r
@assert b blockaxes(axis, 1)
end
return r
end

using BlockArrays: BlockSlice
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end

function cartesianindices(a::AbstractArray, b::Block)
return cartesianindices(axes(a), b)
end

# Output which blocks of `axis` are contained within the unit range `range`.
# The start and end points must match.
function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange)
# TODO: Add a test that the start and end points of the ranges match.
return findblock(axis, first(range)):findblock(axis, last(range))
end

function block_stored_indices(a::AbstractArray)
return Block.(Tuple.(stored_indices(blocks(a))))
end

_block(indices) = block(indices)
_block(indices::CartesianIndices) = Block(ntuple(Returns(1), ndims(indices)))

function combine_axes(as::Vararg{Tuple})
@assert allequal(length.(as))
ndims = length(first(as))
return ntuple(ndims) do dim
dim_axes = map(a -> a[dim], as)
return reduce(BlockArrays.combine_blockaxes, dim_axes)
end
end

# Returns `BlockRange`
# Convert the block of the axes to blocks of the subaxes.
function subblocks(axes::Tuple, subaxes::Tuple, block::Block)
@assert length(axes) == length(subaxes)
return BlockRange(
ntuple(length(axes)) do dim
findblocks(subaxes[dim], axes[dim][Tuple(block)[dim]])
end,
)
end

# Returns `Vector{<:Block}`
function subblocks(axes::Tuple, subaxes::Tuple, blocks)
return mapreduce(vcat, blocks; init=eltype(blocks)[]) do block
return vec(subblocks(axes, subaxes, block))
end
end

# Returns `Vector{<:CartesianIndices}`
function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks)
return map(subblocks(axes, subaxes, blocks)) do block
return cartesianindices(subaxes, block)
end
end
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
using BlockArrays: AbstractBlockArray, BlocksView
using ..SparseArrayInterface: SparseArrayInterface, nstored

function SparseArrayInterface.nstored(a::AbstractBlockArray)
return sum(b -> nstored(b), blocks(a); init=zero(Int))
end

# TODO: Handle `BlocksView` wrapping a sparse array?
function SparseArrayInterface.storage_indices(a::BlocksView)
return CartesianIndices(a)
end
2 changes: 2 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
include("blocksparsearrayinterface/linearalgebra.jl")
include("blocksparsearrayinterface/blockzero.jl")
include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
include("abstractblocksparsearray/abstractblocksparsevector.jl")
include("abstractblocksparsearray/view.jl")
include("abstractblocksparsearray/arraylayouts.jl")
include("abstractblocksparsearray/sparsearrayinterface.jl")
include("abstractblocksparsearray/linearalgebra.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,26 @@ Base.axes(::AbstractBlockSparseArray) = error("Not implemented")

blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented")

# Specialized in order to fix ambiguity error with `BlockArrays`.
## # Specialized in order to fix ambiguity error with `BlockArrays`.
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
return blocksparse_getindex(a, I...)
end

# Fix ambiguity error with `BlockArrays`.
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N}
return ArrayLayouts.layout_getindex(a, I)
end

# Fix ambiguity error with `BlockArrays`.
function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1})
return ArrayLayouts.layout_getindex(a, I)
end

# Fix ambiguity error with `BlockArrays`.
function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector})
return blocksparse_getindex(a, I...)
end
## # Fix ambiguity error with `BlockArrays`.
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N}
## return ArrayLayouts.layout_getindex(a, I)
## end
##
## # Fix ambiguity error with `BlockArrays`.
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1})
## return ArrayLayouts.layout_getindex(a, I)
## end
##
## # Fix ambiguity error with `BlockArrays`.
## function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector})
## ## return blocksparse_getindex(a, I...)
## return ArrayLayouts.layout_getindex(a, I...)
## end

# Specialized in order to fix ambiguity error with `BlockArrays`.
function Base.setindex!(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
using ArrayLayouts: ArrayLayouts, MemoryLayout, MatMulMatAdd, MulAdd
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
using BlockArrays: BlockLayout
using ..SparseArrayInterface: SparseLayout
using LinearAlgebra: mul!

# TODO: Generalize to `BlockSparseArrayLike`.
function ArrayLayouts.MemoryLayout(arraytype::Type{<:AbstractBlockSparseArray})
function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
inner_layout = typeof(MemoryLayout(blocktype(arraytype)))
return BlockLayout{outer_layout,inner_layout}()
Expand All @@ -16,14 +15,9 @@ function Base.similar(
return similar(BlockSparseArray{elt}, axes)
end

function ArrayLayouts.materialize!(
m::MatMulMatAdd{
<:BlockLayout{<:SparseLayout},
<:BlockLayout{<:SparseLayout},
<:BlockLayout{<:SparseLayout},
},
)
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
mul!(a_dest, a1, a2, α, β)
# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
a_dest = BlockSparseArray{eltype(a)}(axes)
a_dest .= a
return a_dest
end
Original file line number Diff line number Diff line change
@@ -1,5 +1,48 @@
using BlockArrays: BlockedUnitRange, BlockSlice
using Base.Broadcast: Broadcast

function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike})
return BlockSparseArrayStyle{ndims(arraytype)}()
end

# Fix ambiguity error with `BlockArrays`.
function Broadcast.BroadcastStyle(
arraytype::Type{
<:SubArray{
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
},
},
)
return BlockSparseArrayStyle{ndims(arraytype)}()
end
function Broadcast.BroadcastStyle(
arraytype::Type{
<:SubArray{
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{
BlockSlice{<:Any,<:BlockedUnitRange},
BlockSlice{<:Any,<:BlockedUnitRange},
Vararg{Any},
},
},
},
)
return BlockSparseArrayStyle{ndims(arraytype)}()
end
function Broadcast.BroadcastStyle(
arraytype::Type{
<:SubArray{
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Any,BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
},
},
)
return BlockSparseArrayStyle{ndims(arraytype)}()
end
Loading

0 comments on commit 957f2af

Please sign in to comment.