Skip to content

Commit

Permalink
[BlockSparseArrays] Function for fusing dimensions of `BlockSparseArr…
Browse files Browse the repository at this point in the history
…ay` (#1246)
  • Loading branch information
mtfishman authored Nov 10, 2023
1 parent 7eb2e30 commit 513fcf0
Show file tree
Hide file tree
Showing 13 changed files with 514 additions and 7 deletions.
115 changes: 115 additions & 0 deletions NDTensors/src/BlockSparseArrays/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,121 @@ end
main()
````

# BlockSparseArrays.jl and BlockArrays.jl interface

````julia
using NDTensors.BlockSparseArrays
using BlockArrays: BlockArrays

i1 = [2, 3]
i2 = [2, 3]
B = BlockSparseArray{Float64}(i1, i2)
B[BlockArrays.Block(1, 1)] = randn(2, 2)
B[BlockArrays.Block(2, 2)] = randn(3, 3)

# Minimal interface

# Specifies the block structure
@show collect.(BlockArrays.blockaxes(axes(B, 1)))

# Index range of a block
@show axes(B, 1)[BlockArrays.Block(1)]

# Last index of each block
@show BlockArrays.blocklasts(axes(B, 1))

# Find the block containing the index
@show BlockArrays.findblock(axes(B, 1), 3)

# Retrieve a block
@show B[BlockArrays.Block(1, 1)]
@show BlockArrays.viewblock(B, BlockArrays.Block(1, 1))

# Check block bounds
@show BlockArrays.blockcheckbounds(B, 2, 2)
@show BlockArrays.blockcheckbounds(B, BlockArrays.Block(2, 2))

# Derived interface

# Enumerate every block index of the array (product of the per-dimension block ranges)
@show collect(Iterators.product(BlockArrays.blockaxes(B)...))

# Iterate over block views
@show sum.(BlockArrays.eachblock(B))

# Reshape into 1-d
@show BlockArrays.blockvec(B)[BlockArrays.Block(1)]

# Array-of-array view
@show BlockArrays.blocks(B)[1, 1] == B[BlockArrays.Block(1, 1)]

# Access an index within a block
@show B[BlockArrays.Block(1, 1)[1, 1]] == B[1, 1]
````

# Proposals for interfaces based on `BlockArrays.jl`, `SparseArrays`, and `BlockSparseArrays.jl`

```julia
# BlockSparseArray interface

# Define `eachblockindex`
eachblockindex(B::BlockArrays.AbstractBlockArray) = Iterators.product(BlockArrays.blockaxes(B)...)

eachblockindex(B::BlockArrays.AbstractBlockArray, b::Block) # indices in a block

blocksize(B::BlockArrays.AbstractBlockArray, b::Block) # size of a block
blocksize(axes, b::Block) # size of a block

blocklength(B::BlockArrays.AbstractBlockArray, b::Block) # length of a block
blocklength(axes, b::Block) # length of a block

# Other functions
BlockArrays.blocksize(B) # number of blocks in each dimension
BlockArrays.blocksizes(B) # length of blocks in each dimension

tuple_block(Block(2, 2)) == (Block(2), Block(2)) # Block.(b.n)
blocksize(axes, b::Block) = map(axis -> length(axis[Block(b.n)]), axes)
blocksize(B, Block(2, 2)) = size(B[Block(2, 2)]) # size of a specified block

# SparseArrays interface

findnz(S) # outputs nonzero keys and values (SparseArrayKit.nonzero_pairs)
nonzeros(S) # vector of structural nonzeros (SparseArrayKit.nonzero_values)
nnz(S) # number of nonzero values (SparseArrayKit.nonzero_length)
rowvals(S) # row that each nonzero value in `nonzeros(S)` is in
nzrange(S, c) # range of linear indices into `nonzeros(S)` for values in column `c`
findall(!iszero, S) # CartesianIndices of numerical nonzeros
issparse(S)
sparse(A) # convert to sparse
dropzeros!(S)
droptol!(S, tol)

# BlockSparseArrays.jl + SparseArrays

blockfindnz(B) # outputs nonzero block indices/keys and block views
blocknonzeros(B)
blocknnz(S)
blockfindall(!iszero, B)
isblocksparse(B)
blocksparse(A)
blockdropzeros!(B)
blockdroptol!(B, tol)

# SparseArrayKit.jl interface

nonzero_pairs(a) # SparseArrays.findnz
nonzero_keys(a) # SparseArrays.?
nonzero_values(a) # SparseArrays.nonzeros
nonzero_length(a) # SparseArrays.nnz

# BlockSparseArrays.jl + SparseArrayKit.jl interface

block_nonzero_pairs
block_nonzero_keys
block_nonzero_values
block_nonzero_length
```

You can generate this README with:
```julia
using Literate
Expand Down
115 changes: 115 additions & 0 deletions NDTensors/src/BlockSparseArrays/examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,121 @@ end

main()

# # BlockSparseArrays.jl and BlockArrays.jl interface

using NDTensors.BlockSparseArrays
using BlockArrays: BlockArrays

i1 = [2, 3]
i2 = [2, 3]
B = BlockSparseArray{Float64}(i1, i2)
B[BlockArrays.Block(1, 1)] = randn(2, 2)
B[BlockArrays.Block(2, 2)] = randn(3, 3)

## Minimal interface

## Specifies the block structure
@show collect.(BlockArrays.blockaxes(axes(B, 1)))

## Index range of a block
@show axes(B, 1)[BlockArrays.Block(1)]

## Last index of each block
@show BlockArrays.blocklasts(axes(B, 1))

## Find the block containing the index
@show BlockArrays.findblock(axes(B, 1), 3)

## Retrieve a block
@show B[BlockArrays.Block(1, 1)]
@show BlockArrays.viewblock(B, BlockArrays.Block(1, 1))

## Check block bounds
@show BlockArrays.blockcheckbounds(B, 2, 2)
@show BlockArrays.blockcheckbounds(B, BlockArrays.Block(2, 2))

## Derived interface

## Enumerate every block index of the array (product of the per-dimension block ranges)
@show collect(Iterators.product(BlockArrays.blockaxes(B)...))

## Iterate over block views
@show sum.(BlockArrays.eachblock(B))

## Reshape into 1-d
@show BlockArrays.blockvec(B)[BlockArrays.Block(1)]

## Array-of-array view
@show BlockArrays.blocks(B)[1, 1] == B[BlockArrays.Block(1, 1)]

## Access an index within a block
@show B[BlockArrays.Block(1, 1)[1, 1]] == B[1, 1]

# # Proposals for interfaces based on `BlockArrays.jl`, `SparseArrays`, and `BlockSparseArrays.jl`

#=
```julia
# BlockSparseArray interface
# Define `eachblockindex`
eachblockindex(B::BlockArrays.AbstractBlockArray) = Iterators.product(BlockArrays.blockaxes(B)...)
eachblockindex(B::BlockArrays.AbstractBlockArray, b::Block) # indices in a block
blocksize(B::BlockArrays.AbstractBlockArray, b::Block) # size of a block
blocksize(axes, b::Block) # size of a block
blocklength(B::BlockArrays.AbstractBlockArray, b::Block) # length of a block
blocklength(axes, b::Block) # length of a block
# Other functions
BlockArrays.blocksize(B) # number of blocks in each dimension
BlockArrays.blocksizes(B) # length of blocks in each dimension
tuple_block(Block(2, 2)) == (Block(2), Block(2)) # Block.(b.n)
blocksize(axes, b::Block) = map(axis -> length(axis[Block(b.n)]), axes)
blocksize(B, Block(2, 2)) = size(B[Block(2, 2)]) # size of a specified block
# SparseArrays interface
findnz(S) # outputs nonzero keys and values (SparseArrayKit.nonzero_pairs)
nonzeros(S) # vector of structural nonzeros (SparseArrayKit.nonzero_values)
nnz(S) # number of nonzero values (SparseArrayKit.nonzero_length)
rowvals(S) # row that each nonzero value in `nonzeros(S)` is in
nzrange(S, c) # range of linear indices into `nonzeros(S)` for values in column `c`
findall(!iszero, S) # CartesianIndices of numerical nonzeros
issparse(S)
sparse(A) # convert to sparse
dropzeros!(S)
droptol!(S, tol)
# BlockSparseArrays.jl + SparseArrays
blockfindnz(B) # outputs nonzero block indices/keys and block views
blocknonzeros(B)
blocknnz(S)
blockfindall(!iszero, B)
isblocksparse(B)
blocksparse(A)
blockdropzeros!(B)
blockdroptol!(B, tol)
# SparseArrayKit.jl interface
nonzero_pairs(a) # SparseArrays.findnz
nonzero_keys(a) # SparseArrays.?
nonzero_values(a) # SparseArrays.nonzeros
nonzero_length(a) # SparseArrays.nnz
# BlockSparseArrays.jl + SparseArrayKit.jl interface
block_nonzero_pairs
block_nonzero_keys
block_nonzero_values
block_nonzero_length
```
=#

#=
You can generate this README with:
```julia
Expand Down
8 changes: 8 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@ module BlockSparseArrays
using BlockArrays
using Compat
using Dictionaries
using SplitApplyCombine

using BlockArrays: block

export BlockSparseArray, SparseArray

include("tensor_product.jl")
include("base.jl")
include("axes.jl")
include("abstractarray.jl")
include("permuteddimsarray.jl")
include("blockarrays.jl")
include("sparsearray.jl")
include("blocksparsearray.jl")
include("subarray.jl")
include("broadcast.jl")
include("fusedims.jl")
include("gradedrange.jl")

end
19 changes: 19 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/axes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# TODO: Delete
## function NDTensors.outer(s1::Base.OneTo, s2::Base.OneTo)
## return Base.OneTo(length(s1) * length(s2))
## end

"""
    blockmerge(s::Base.OneTo, grouped_perm::Vector{Vector{Int}})

Block-merge method for a plain `Base.OneTo` axis.

Only the trivial grouping `[[1]]` is accepted (checked with `@assert`); the
axis `s` is then returned unchanged.
"""
function blockmerge(s::Base.OneTo, grouped_perm::Vector{Vector{Int}})
  # Internal invariant: the only grouping this method supports is the one
  # produced by `blockmergesortperm(::Base.OneTo)`.
  @assert grouped_perm == [[1]]
  return s
end

"""
    blockmergesortperm(s::Base.OneTo)

Return the grouped permutation for a plain `Base.OneTo` axis, which is always
the trivial single group `[[1]]` regardless of the length of `s`.
"""
function blockmergesortperm(s::Base.OneTo)
  return [[1]]
end

"""
    sub_axis(a::BlockedUnitRange, blocks)

Build a new blocked range from a selection of blocks of `a`.

For every `b` in `blocks` (in iteration order) the length of `a[b]` is
collected, and the resulting lengths are passed to `blockedrange`.
"""
function sub_axis(a::BlockedUnitRange, blocks)
  selected_lengths = [length(a[b]) for b in blocks]
  return blockedrange(selected_lengths)
end

"""
    sub_axes(axes_src::Tuple, axes_parent::Tuple)

Apply `sub_axis` elementwise (by broadcasting) to the paired entries of
`axes_src` and `axes_parent` and return the results.
"""
function sub_axes(axes_src::Tuple, axes_parent::Tuple)
  return broadcast(sub_axis, axes_src, axes_parent)
end
14 changes: 14 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/base.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
    groupsorted(v)

Count the occurrences of each distinct element of `v` by calling
`groupcount(identity, v)`.

NOTE(review): the name suggests `v` is expected to be sorted already, but
nothing here enforces that — confirm against callers.
"""
groupsorted(v) = groupcount(identity, v)

"""
    groupsortperm(v)

Compute the sorting permutation of `v` and split it into groups, one group per
run of equal elements in the sorted order.

Example (from the original author):
`groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]`
"""
function groupsortperm(v)
  sorting_perm = sortperm(v)
  # Count how many entries share each value, looking at `v` in sorted order.
  counts = groupsorted(view(v, sorting_perm))
  # Use the counts as block lengths to chop the permutation into groups.
  return blocks(BlockVector(sorting_perm, collect(counts)))
end

"""
    tuple_cat(ts::Tuple...)

Concatenate any number of tuples into a single flat tuple, preserving order.

Calling it with no arguments returns the empty tuple `()`.

# Examples
```julia
tuple_cat((1, 2), (3,)) == (1, 2, 3)
tuple_cat() == ()
```
"""
function tuple_cat(ts::Tuple...)
  # `init=()` makes the zero-argument call well defined; without it `reduce`
  # throws on an empty collection.
  return reduce((x, y) -> (x..., y...), ts; init=())
end
3 changes: 3 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/blockarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Extensions to BlockArrays.jl
"""
    blocktuple(b::Block)

Split the block index `b` into a tuple of one-dimensional `Block`s, one for
each entry of `b.n`.
"""
function blocktuple(b::Block)
  return map(Block, b.n)
end
"""
    inttuple(b::Block)

Return the contents of the `n` field of the block index `b`.
"""
function inttuple(b::Block)
  return b.n
end
17 changes: 12 additions & 5 deletions NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ struct BlockSparseArray{
end

"""
    Base.axes(block_arr::BlockSparseArray)

Return the axes stored in the `axes` field of `block_arr`.
"""
function Base.axes(block_arr::BlockSparseArray)
  return block_arr.axes
end
"""
    BlockArrays.blocks(a::BlockSparseArray)

Return the block storage held in the `blocks` field of `a`.

The method is defined on the qualified name so that it extends
`BlockArrays.blocks` rather than introducing a separate module-local `blocks`
function.
"""
BlockArrays.blocks(a::BlockSparseArray) = a.blocks
# TODO: Use `SetParameters`.
"""
    blocktype(a::BlockSparseArray)

Return the third type parameter of the `BlockSparseArray` type of `a`.

NOTE(review): presumably this is the array type used to store each block —
confirm against the struct definition.
"""
function blocktype(a::BlockSparseArray{<:Any,<:Any,A}) where {A}
  return A
end

# TODO: Use `SetParameters`.
"""
    set_ndims(::Type{<:Array{T}}, n) where {T}

Return the `Array` type with element type `T` and dimensionality `n`, i.e.
`Array{T,n}`. Any dimensionality carried by the input type is ignored.
"""
function set_ndims(::Type{<:Array{T}}, n) where {T}
  return Array{T,n}
end

"""
    nonzero_blockkeys(a::BlockSparseArray)

Return the keys reported by `nonzero_keys(blocks(a))`, each converted to a
`Block` by first turning the key into a `Tuple`.

NOTE(review): `nonzero_keys` is defined elsewhere; presumably it yields the
Cartesian indices of the blocks that are actually stored — confirm.
"""
function nonzero_blockkeys(a::BlockSparseArray)
  # `Block ∘ Tuple` composes the two conversions: key -> Tuple -> Block.
  return map(Block ∘ Tuple, collect(nonzero_keys(blocks(a))))
end

function Base.reshape(a::BlockSparseArray, ax::Tuple{Vararg{AbstractUnitRange}})
## TODO: Use `SparseArray` reshape in some way?
## blocks_reshaped = reshape(blocks(a), blocklength.(ax))
Expand Down Expand Up @@ -43,7 +47,7 @@ end

# The size of a block
"""
    block_size(axes::Tuple{Vararg{AbstractUnitRange}}, block::Block)

Return the size of the block `block` for the given `axes`: along each
dimension, the length of the index range obtained by indexing that axis with
the corresponding one-dimensional block from `blocktuple(block)`.
"""
function block_size(axes::Tuple{Vararg{AbstractUnitRange}}, block::Block)
  # Uses the shared `blocktuple` helper instead of spelling out `Block.(block.n)`.
  return length.(getindex.(axes, blocktuple(block)))
end

# The size of a block
Expand Down Expand Up @@ -89,6 +93,7 @@ end
"""
    BlockSparseArray{T,N,B}(undef, axes::Tuple{Vararg{Any,N}})

Construct a `BlockSparseArray{T,N,B}` over `axes` with an empty list of
blocks, by forwarding to the `(undef, blocks, axes)` constructor.
"""
function BlockSparseArray{T,N,B}(
  ::UndefInitializer, axes::Tuple{Vararg{Any,N}}
) where {T,N,B}
  # TODO: `Block{N,Int}`?
  empty_blocks = Block{N}[]
  return BlockSparseArray{T,N,B}(undef, empty_blocks, axes)
end
Expand All @@ -100,7 +105,7 @@ function BlockSparseArray(
cartesianblocks = if isempty(blockdata)
Dictionary{Block{N},CartesianIndex{N}}()
else
map(block -> CartesianIndex(block.n), blocks)
map(block -> CartesianIndex(inttuple(block)), blocks)
end
cartesiandata = Dictionary(cartesianblocks, blockdata)
block_storage = SparseArray(cartesiandata, blocklength.(axes), BlockZero(axes))
Expand Down Expand Up @@ -151,7 +156,7 @@ function Base.copy(block_arr::BlockSparseArray)
end

function BlockArrays.viewblock(block_arr::BlockSparseArray, block)
blks = block.n
blks = inttuple(block)
@boundscheck blockcheckbounds(block_arr, blks...)
## block_size = length.(getindex.(axes(block_arr), Block.(blks)))
# TODO: Make this `Zeros`?
Expand Down Expand Up @@ -241,7 +246,9 @@ function Base.permutedims(a::BlockSparseArray, perm)
end

# TODO: Make `PermutedBlockSparseArray`.
"""
    BlockArrays.blocks(a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:BlockSparseArray})

Return the blocks of a permuted `BlockSparseArray`: the blocks of the parent
array wrapped in a `PermutedDimsArray` using the permutation `perm(a)`.

The method is defined on the qualified name so that it extends
`BlockArrays.blocks`.
"""
function BlockArrays.blocks(
  a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:BlockSparseArray}
)
  return PermutedDimsArray(blocks(parent(a)), perm(a))
end

Expand Down
Loading

0 comments on commit 513fcf0

Please sign in to comment.