Skip to content

Commit

Permalink
[NDTensors] BlockSparseArrays prototype (#1205)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Oct 4, 2023
1 parent a9eb3cf commit 15decbd
Show file tree
Hide file tree
Showing 16 changed files with 422 additions and 41 deletions.
2 changes: 2 additions & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.2.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Expand Down Expand Up @@ -36,6 +37,7 @@ NDTensorsTBLISExt = "TBLIS"

[compat]
Adapt = "3.5"
BlockArrays = "0.16"
Compat = "4.9"
Dictionaries = "0.3.5"
FLoops = "0.2.1"
Expand Down
59 changes: 59 additions & 0 deletions NDTensors/src/BlockSparseArrays/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# BlockSparseArrays.jl

A Julia `BlockSparseArray` type based on the `BlockArrays.jl` interface.

It wraps an elementwise `SparseArray` type that uses a dictionary-of-keys
to store non-zero values, specifically a `Dictionary` from `Dictionaries.jl`.
`BlockArrays` reinterprets the `SparseArray` as a blocked data structure.

````julia
using NDTensors.BlockSparseArrays
using BlockArrays: BlockArrays, blockedrange

# Block dimensions
i1 = [2, 3]
i2 = [2, 3]

i_axes = (blockedrange(i1), blockedrange(i2))

function block_size(axes, block)
return length.(getindex.(axes, BlockArrays.Block.(block.n)))
end

# Data
nz_blocks = BlockArrays.Block.([(1, 1), (2, 2)])
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
nz_block_lengths = prod.(nz_block_sizes)

# Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)

# Blocks with contiguous underlying data
# d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths)
# d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)]

B = BlockSparseArray(nz_blocks, d_blocks, i_axes)

# Access a block
B[BlockArrays.Block(1, 1)]

# Access a zero (unstored) block, returns a zero matrix
B[BlockArrays.Block(1, 2)]

# Set a zero block
B[BlockArrays.Block(1, 2)] = randn(2, 3)

# Matrix multiplication (not optimized for sparsity yet)
B * B
````

You can generate this README with:
```julia
using Literate
Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
```

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

52 changes: 52 additions & 0 deletions NDTensors/src/BlockSparseArrays/examples/README.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# # BlockSparseArrays.jl
#
# A Julia `BlockSparseArray` type based on the `BlockArrays.jl` interface.
#
# It wraps an elementwise `SparseArray` type that uses a dictionary-of-keys
# to store non-zero values, specifically a `Dictionary` from `Dictionaries.jl`.
# `BlockArrays` reinterprets the `SparseArray` as a blocked data structure.

using NDTensors.BlockSparseArrays
using BlockArrays: BlockArrays, blockedrange

## Block dimensions
i1 = [2, 3]
i2 = [2, 3]

i_axes = (blockedrange(i1), blockedrange(i2))

## The size of a block within the given blocked axes.
function block_size(axes, block)
  return length.(getindex.(axes, BlockArrays.Block.(block.n)))
end

## Data
nz_blocks = BlockArrays.Block.([(1, 1), (2, 2)])
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
nz_block_lengths = prod.(nz_block_sizes)

## Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)

## Blocks with contiguous underlying data
## d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths)
## d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)]

B = BlockSparseArray(nz_blocks, d_blocks, i_axes)

## Access a block
B[BlockArrays.Block(1, 1)]

## Access a zero (unstored) block, returns a zero matrix
B[BlockArrays.Block(1, 2)]

## Set a zero block
B[BlockArrays.Block(1, 2)] = randn(2, 3)

## Matrix multiplication (not optimized for sparsity yet)
B * B

# You can generate this README with:
# ```julia
# using Literate
# Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
# ```
12 changes: 12 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Prototype block-sparse array type built on the `BlockArrays.jl` blocking
# interface, with per-block storage kept in a `Dictionaries.jl` dictionary.
module BlockSparseArrays
using BlockArrays
using Dictionaries

using BlockArrays: block

export BlockSparseArray, SparseArray

# `SparseArray` (dictionary-of-keys element storage) must be loaded first
# since `BlockSparseArray` wraps it.
include("sparsearray.jl")
include("blocksparsearray.jl")

end
121 changes: 121 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
using BlockArrays: block

# Also add a version with contiguous underlying data.
# Block-sparse array: `blocks` is an N-dimensional `SparseArray` whose elements
# are the stored (structurally non-zero) blocks, keyed by block position;
# `axes` are the blocked axes (e.g. from `blockedrange`) that define the size
# of every block along each dimension.
struct BlockSparseArray{
  T,N,Blocks<:SparseArray{<:AbstractArray{T,N},N},Axes<:NTuple{N,AbstractUnitRange{Int}}
} <: AbstractBlockArray{T,N}
  blocks::Blocks
  axes::Axes
end

# The size of a block
# Size (extent along each dimension) of the block `block` within the blocked
# axes `axes`, as a tuple of lengths.
function block_size(axes::Tuple, block::Block)
  return map((axis, b) -> length(axis[Block(b)]), axes, block.n)
end

# Lazy generator of structurally-zero blocks: holds the blocked axes so it can
# materialize a correctly-sized zero block for any block position on demand.
struct BlockZero{Axes}
  axes::Axes
end

# Build a zero-filled block of array type `T` for the block at Cartesian block
# position `I`. `fill!(..., false)` writes a generic zero for any eltype.
function (f::BlockZero)(T::Type, I::CartesianIndex)
  return fill!(T(undef, block_size(f.axes, Block(Tuple(I)))), false)
end

# Construct from parallel vectors of block positions and block data, e.g.
# `BlockSparseArray([Block(1, 1)], [randn(2, 2)], (blockedrange([2]), blockedrange([2])))`.
function BlockSparseArray(
  blocks::AbstractVector{<:Block{N}}, blockdata::AbstractVector, axes::NTuple{N}
) where {N}
  return BlockSparseArray(Dictionary(blocks, blockdata), axes)
end

# Construct from a dictionary mapping `Block` positions to block data.
# The `Block` keys are converted to `CartesianIndex` keys for the underlying
# `SparseArray`, whose missing entries are produced by a `BlockZero` that
# knows the blocked axes (so zero blocks come out correctly sized).
function BlockSparseArray(
  blockdata::Dictionary{<:Block{N}}, axes::NTuple{N,AbstractUnitRange{Int}}
) where {N}
  blocks = keys(blockdata)
  cartesianblocks = map(block -> CartesianIndex(block.n), blocks)
  cartesiandata = Dictionary(cartesianblocks, blockdata)
  # The sparse storage is indexed by block position, so its dims are the
  # number of blocks along each axis.
  block_storage = SparseArray(cartesiandata, blocklength.(axes), BlockZero(axes))
  return BlockSparseArray(block_storage, axes)
end

# Convenience constructor: accept per-dimension block-size vectors
# (e.g. `([2, 3], [2, 3])`) and convert them to blocked axes.
function BlockSparseArray(
  blockdata::Dictionary{<:Block{N}}, blockinds::NTuple{N,AbstractVector}
) where {N}
  return BlockSparseArray(blockdata, blockedrange.(blockinds))
end

Base.axes(block_arr::BlockSparseArray) = block_arr.axes

# `deepcopy` the block storage so the copy shares no data with the original:
# the stored blocks are themselves arrays, so a shallow `copy` of the
# dictionary would alias them.
function Base.copy(block_arr::BlockSparseArray)
  return BlockSparseArray(deepcopy(block_arr.blocks), copy.(block_arr.axes))
end

# Return the data for one block. Absent (structurally zero) blocks are
# materialized by the underlying `SparseArray`'s zero generator, so this
# returns a freshly-allocated zero block on a miss rather than a view of
# stored data — NOTE(review): mutating such a returned block does not write
# back into the array.
function BlockArrays.viewblock(block_arr::BlockSparseArray, block)
  blks = block.n
  @boundscheck blockcheckbounds(block_arr, blks...)
  ## block_size = length.(getindex.(axes(block_arr), Block.(blks)))
  # TODO: Make this `Zeros`?
  ## zero = zeros(eltype(block_arr), block_size)
  return block_arr.blocks[blks...] # Fails because zero isn't defined
  ## return get_nonzero(block_arr.blocks, blks, zero)
end

# Scalar lookup addressed by a `BlockIndex` (block position plus intra-block
# offsets): view the block, then index into it with the offsets.
function Base.getindex(block_arr::BlockSparseArray{T,N}, bi::BlockIndex{N}) where {T,N}
  @boundscheck blockcheckbounds(block_arr, Block(bi.I))
  blockview = view(block_arr, block(bi))
  offsets = bi.α
  @boundscheck checkbounds(blockview, offsets...)
  return blockview[offsets...]
end

# Scalar assignment: locate the block containing the flat indices `i`, write
# the element into a view of that block, then store the block back so that a
# previously-zero (unstored) block becomes a stored block.
function Base.setindex!(
  block_arr::BlockSparseArray{T,N}, v, i::Vararg{Integer,N}
) where {T,N}
  @boundscheck checkbounds(block_arr, i...)
  # Per-dimension block position and offset within that block.
  block_indices = findblockindex.(axes(block_arr), i)
  block = map(block_index -> Block(block_index.I), block_indices)
  offsets = map(block_index -> only(block_index.α), block_indices)
  block_view = @view block_arr[block...]
  block_view[offsets...] = v
  # Re-assign the whole block: if it was structurally zero, the view above is
  # a freshly-materialized block that must be inserted into the storage.
  block_arr[block...] = block_view
  return block_arr
end

# Validate that `v` has the correct size to be stored as block `block`
# (per-dimension block indices); throws `DimensionMismatch` otherwise.
# Mirrors `BlockArrays._check_setblock!` for `BlockArray`.
function BlockArrays._check_setblock!(
  block_arr::BlockSparseArray{T,N}, v, block::NTuple{N,Integer}
) where {T,N}
  for i in 1:N
    # Expected extent of this block along dimension `i`.
    bsz = length(axes(block_arr, i)[Block(block[i])])
    if size(v, i) != bsz
      throw(
        DimensionMismatch(
          string(
            "tried to assign $(size(v)) array to ",
            # Bug fix: index the axes with `Block`s, not raw integers —
            # integer indexing returns scalar elements whose `length` is
            # always 1, so the reported expected block size was wrong.
            length.(getindex.(axes(block_arr), Block.(block))),
            " block",
          ),
        ),
      )
    end
  end
end

# Assign an entire block, e.g. `B[Block(1), Block(2)] = v`. The array `v` is
# stored directly (not copied) after its size is checked.
function Base.setindex!(
  block_arr::BlockSparseArray{T,N}, v, block::Vararg{Block{1},N}
) where {T,N}
  blks = Int.(block)
  @boundscheck blockcheckbounds(block_arr, blks...)
  @boundscheck BlockArrays._check_setblock!(block_arr, v, blks)
  # This fails since it tries to replace the element
  block_arr.blocks[blks...] = v
  # Use .= here to overwrite data.
  ## block_view = @view block_arr[Block(blks)]
  ## block_view .= v
  return block_arr
end

# Scalar indexing: translate flat indices into block-relative `BlockIndex`es
# and delegate to the `BlockIndex` method.
function Base.getindex(block_arr::BlockSparseArray{T,N}, i::Vararg{Integer,N}) where {T,N}
  @boundscheck checkbounds(block_arr, i...)
  blockindices = findblockindex.(axes(block_arr), i)
  return block_arr[blockindices...]
end
32 changes: 32 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/sparsearray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Elementwise sparse array using dictionary-of-keys storage: `data` maps a
# `CartesianIndex` to its stored value, `dims` is the overall array size, and
# `zero` is a callable `(T, I) -> value` that produces the value for indices
# with no stored entry.
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
  data::Dictionary{CartesianIndex{N},T}
  dims::NTuple{N,Int64}
  zero::Zero
end

Base.size(a::SparseArray) = a.dims

# Assignment inserts or overwrites the entry in the dictionary
# (`Dictionaries.set!` handles both cases).
function Base.setindex!(a::SparseArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
  set!(a.data, I, v)
  return a
end
function Base.setindex!(a::SparseArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
  return setindex!(a, v, CartesianIndex(I))
end

# Element lookup: return the stored value, or a structural zero produced by
# `a.zero` for unstored indices.
# Fix: the original `get(a.data, I, a.zero(T, I))` evaluated `a.zero(T, I)`
# eagerly, allocating and zero-filling a whole block on EVERY access even
# when the entry is stored; only construct the zero on a miss.
function Base.getindex(a::SparseArray{T,N}, I::CartesianIndex{N}) where {T,N}
  return haskey(a.data, I) ? a.data[I] : a.zero(T, I)
end
function Base.getindex(a::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
  return getindex(a, CartesianIndex(I))
end

## # `getindex` but uses a default if the value is
## # structurally zero.
## function get_nonzero(a::SparseArray{T,N}, I::CartesianIndex{N}, zero) where {T,N}
## @boundscheck checkbounds(a, I)
## return get(a.data, I, zero)
## end
## function get_nonzero(a::SparseArray{T,N}, I::NTuple{N,Int}, zero) where {T,N}
## return get_nonzero(a, CartesianIndex(I), zero)
## end
12 changes: 12 additions & 0 deletions NDTensors/src/BlockSparseArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Test
using NDTensors.BlockSparseArrays

@testset "Test NDTensors.BlockSparseArrays" begin
  @testset "README" begin
    # Execute the Literate.jl example script; `isa Any` is trivially true,
    # so the test asserts only that the script runs without throwing.
    @test include(
      joinpath(
        pkgdir(BlockSparseArrays), "src", "BlockSparseArrays", "examples", "README.jl"
      ),
    ) isa Any
  end
end
3 changes: 3 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("BlockSparseArrays/src/BlockSparseArrays.jl")
using .BlockSparseArrays
include("SmallVectors/src/SmallVectors.jl")
using .SmallVectors
include("SortedSets/src/SortedSets.jl")
Expand Down Expand Up @@ -122,6 +124,7 @@ include("empty/adapt.jl")
#
include("arraytensor/arraytensor.jl")
include("arraytensor/array.jl")
include("arraytensor/blocksparsearray.jl")

#####################################
# Deprecations
Expand Down
8 changes: 7 additions & 1 deletion NDTensors/src/arraytensor/arraytensor.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Used for dispatch to distinguish from Tensors wrapping TensorStorage.
# Remove once TensorStorage is removed.
# Array-like types that can serve as raw storage for an array-backed Tensor;
# now also includes the prototype `BlockSparseArray`.
const ArrayStorage{T,N} = Union{
  Array{T,N},
  ReshapedArray{T,N},
  SubArray{T,N},
  PermutedDimsArray{T,N},
  StridedView{T,N},
  BlockSparseArray{T,N},
}
const MatrixStorage{T} = Union{
ArrayStorage{T,2},
Expand Down Expand Up @@ -41,6 +46,7 @@ function setindex!(tensor::MatrixOrArrayStorageTensor, v, I::Integer...)
return tensor
end

# TODO: Just call `contraction_output(storage(tensor1), storage(tensor2), indsR)`
function contraction_output(
tensor1::MatrixOrArrayStorageTensor, tensor2::MatrixOrArrayStorageTensor, indsR
)
Expand Down
16 changes: 16 additions & 0 deletions NDTensors/src/arraytensor/blocksparsearray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# TODO: Implement.
# Placeholder: allocating the output of a block-sparse contraction is not
# supported yet; always throws.
function contraction_output(tensor1::BlockSparseArray, tensor2::BlockSparseArray, indsR)
  return error("Not implemented")
end

# TODO: Implement.
# Placeholder: labeled in-place block-sparse contraction is not supported yet;
# always throws.
function contract!(
  tensorR::BlockSparseArray,
  labelsR,
  tensor1::BlockSparseArray,
  labels1,
  tensor2::BlockSparseArray,
  labels2,
)
  return error("Not implemented")
end
4 changes: 4 additions & 0 deletions NDTensors/test/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
using Test
using NDTensors

# Forward to the in-package test suite for BlockSparseArrays.
include(joinpath(pkgdir(NDTensors), "src", "BlockSparseArrays", "test", "runtests.jl"))
1 change: 1 addition & 0 deletions NDTensors/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
Expand Down
Loading

0 comments on commit 15decbd

Please sign in to comment.