-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors]
BlockSparseArrays
prototype (#1205)
- Loading branch information
Showing
16 changed files
with
422 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# BlockSparseArrays.jl | ||
|
||
A Julia `BlockSparseArray` type based on the `BlockArrays.jl` interface. | ||
|
||
It wraps an elementwise `SparseArray` type that uses a dictionary-of-keys | ||
to store non-zero values, specifically a `Dictionary` from `Dictionaries.jl`. | ||
`BlockArrays` reinterprets the `SparseArray` as a blocked data structure. | ||
|
||
````julia | ||
using NDTensors.BlockSparseArrays | ||
using BlockArrays: BlockArrays, blockedrange | ||
|
||
# Block dimensions | ||
i1 = [2, 3] | ||
i2 = [2, 3] | ||
|
||
i_axes = (blockedrange(i1), blockedrange(i2)) | ||
|
||
function block_size(axes, block) | ||
return length.(getindex.(axes, BlockArrays.Block.(block.n))) | ||
end | ||
|
||
# Data | ||
nz_blocks = BlockArrays.Block.([(1, 1), (2, 2)]) | ||
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks] | ||
nz_block_lengths = prod.(nz_block_sizes) | ||
|
||
# Blocks with discontiguous underlying data | ||
d_blocks = randn.(nz_block_sizes) | ||
|
||
# Blocks with contiguous underlying data | ||
# d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths) | ||
# d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)] | ||
|
||
B = BlockSparseArray(nz_blocks, d_blocks, i_axes) | ||
|
||
# Access a block | ||
B[BlockArrays.Block(1, 1)] | ||
|
||
# Access a non-zero block, returns a zero matrix | ||
B[BlockArrays.Block(1, 2)] | ||
|
||
# Set a zero block | ||
B[BlockArrays.Block(1, 2)] = randn(2, 3) | ||
|
||
# Matrix multiplication (not optimized for sparsity yet) | ||
B * B | ||
```` | ||
|
||
You can generate this README with: | ||
```julia | ||
using Literate | ||
Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor()) | ||
``` | ||
|
||
--- | ||
|
||
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# # BlockSparseArrays.jl | ||
# | ||
# A Julia `BlockSparseArray` type based on the `BlockArrays.jl` interface. | ||
# | ||
# It wraps an elementwise `SparseArray` type that uses a dictionary-of-keys | ||
# to store non-zero values, specifically a `Dictionary` from `Dictionaries.jl`. | ||
# `BlockArrays` reinterprets the `SparseArray` as a blocked data structure. | ||
|
||
using NDTensors.BlockSparseArrays | ||
using BlockArrays: BlockArrays, blockedrange | ||
|
||
## Block dimensions | ||
i1 = [2, 3] | ||
i2 = [2, 3] | ||
|
||
i_axes = (blockedrange(i1), blockedrange(i2)) | ||
|
||
function block_size(axes, block) | ||
return length.(getindex.(axes, BlockArrays.Block.(block.n))) | ||
end | ||
|
||
## Data | ||
nz_blocks = BlockArrays.Block.([(1, 1), (2, 2)]) | ||
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks] | ||
nz_block_lengths = prod.(nz_block_sizes) | ||
|
||
## Blocks with discontiguous underlying data | ||
d_blocks = randn.(nz_block_sizes) | ||
|
||
## Blocks with contiguous underlying data | ||
## d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths) | ||
## d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)] | ||
|
||
B = BlockSparseArray(nz_blocks, d_blocks, i_axes) | ||
|
||
## Access a block | ||
B[BlockArrays.Block(1, 1)] | ||
|
||
## Access a non-zero block, returns a zero matrix | ||
B[BlockArrays.Block(1, 2)] | ||
|
||
## Set a zero block | ||
B[BlockArrays.Block(1, 2)] = randn(2, 3) | ||
|
||
## Matrix multiplication (not optimized for sparsity yet) | ||
B * B | ||
|
||
# You can generate this README with: | ||
# ```julia | ||
# using Literate | ||
# Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor()) | ||
# ``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
module BlockSparseArrays | ||
using BlockArrays | ||
using Dictionaries | ||
|
||
using BlockArrays: block | ||
|
||
export BlockSparseArray, SparseArray | ||
|
||
include("sparsearray.jl") | ||
include("blocksparsearray.jl") | ||
|
||
end |
121 changes: 121 additions & 0 deletions
121
NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
using BlockArrays: block | ||
|
||
# Also add a version with contiguous underlying data. | ||
struct BlockSparseArray{ | ||
T,N,Blocks<:SparseArray{<:AbstractArray{T,N},N},Axes<:NTuple{N,AbstractUnitRange{Int}} | ||
} <: AbstractBlockArray{T,N} | ||
blocks::Blocks | ||
axes::Axes | ||
end | ||
|
||
# The size of a block | ||
function block_size(axes::Tuple, block::Block) | ||
return length.(getindex.(axes, Block.(block.n))) | ||
end | ||
|
||
struct BlockZero{Axes} | ||
axes::Axes | ||
end | ||
|
||
function (f::BlockZero)(T::Type, I::CartesianIndex) | ||
return fill!(T(undef, block_size(f.axes, Block(Tuple(I)))), false) | ||
end | ||
|
||
function BlockSparseArray( | ||
blocks::AbstractVector{<:Block{N}}, blockdata::AbstractVector, axes::NTuple{N} | ||
) where {N} | ||
return BlockSparseArray(Dictionary(blocks, blockdata), axes) | ||
end | ||
|
||
function BlockSparseArray( | ||
blockdata::Dictionary{<:Block{N}}, axes::NTuple{N,AbstractUnitRange{Int}} | ||
) where {N} | ||
blocks = keys(blockdata) | ||
cartesianblocks = map(block -> CartesianIndex(block.n), blocks) | ||
cartesiandata = Dictionary(cartesianblocks, blockdata) | ||
block_storage = SparseArray(cartesiandata, blocklength.(axes), BlockZero(axes)) | ||
return BlockSparseArray(block_storage, axes) | ||
end | ||
|
||
function BlockSparseArray( | ||
blockdata::Dictionary{<:Block{N}}, blockinds::NTuple{N,AbstractVector} | ||
) where {N} | ||
return BlockSparseArray(blockdata, blockedrange.(blockinds)) | ||
end | ||
|
||
Base.axes(block_arr::BlockSparseArray) = block_arr.axes | ||
|
||
function Base.copy(block_arr::BlockSparseArray) | ||
return BlockSparseArray(deepcopy(block_arr.blocks), copy.(block_arr.axes)) | ||
end | ||
|
||
function BlockArrays.viewblock(block_arr::BlockSparseArray, block) | ||
blks = block.n | ||
@boundscheck blockcheckbounds(block_arr, blks...) | ||
## block_size = length.(getindex.(axes(block_arr), Block.(blks))) | ||
# TODO: Make this `Zeros`? | ||
## zero = zeros(eltype(block_arr), block_size) | ||
return block_arr.blocks[blks...] # Fails because zero isn't defined | ||
## return get_nonzero(block_arr.blocks, blks, zero) | ||
end | ||
|
||
function Base.getindex(block_arr::BlockSparseArray{T,N}, bi::BlockIndex{N}) where {T,N} | ||
@boundscheck blockcheckbounds(block_arr, Block(bi.I)) | ||
bl = view(block_arr, block(bi)) | ||
inds = bi.α | ||
@boundscheck checkbounds(bl, inds...) | ||
v = bl[inds...] | ||
return v | ||
end | ||
|
||
function Base.setindex!( | ||
block_arr::BlockSparseArray{T,N}, v, i::Vararg{Integer,N} | ||
) where {T,N} | ||
@boundscheck checkbounds(block_arr, i...) | ||
block_indices = findblockindex.(axes(block_arr), i) | ||
block = map(block_index -> Block(block_index.I), block_indices) | ||
offsets = map(block_index -> only(block_index.α), block_indices) | ||
block_view = @view block_arr[block...] | ||
block_view[offsets...] = v | ||
block_arr[block...] = block_view | ||
return block_arr | ||
end | ||
|
||
function BlockArrays._check_setblock!( | ||
block_arr::BlockSparseArray{T,N}, v, block::NTuple{N,Integer} | ||
) where {T,N} | ||
for i in 1:N | ||
bsz = length(axes(block_arr, i)[Block(block[i])]) | ||
if size(v, i) != bsz | ||
throw( | ||
DimensionMismatch( | ||
string( | ||
"tried to assign $(size(v)) array to ", | ||
length.(getindex.(axes(block_arr), block)), | ||
" block", | ||
), | ||
), | ||
) | ||
end | ||
end | ||
end | ||
|
||
function Base.setindex!( | ||
block_arr::BlockSparseArray{T,N}, v, block::Vararg{Block{1},N} | ||
) where {T,N} | ||
blks = Int.(block) | ||
@boundscheck blockcheckbounds(block_arr, blks...) | ||
@boundscheck BlockArrays._check_setblock!(block_arr, v, blks) | ||
# This fails since it tries to replace the element | ||
block_arr.blocks[blks...] = v | ||
# Use .= here to overwrite data. | ||
## block_view = @view block_arr[Block(blks)] | ||
## block_view .= v | ||
return block_arr | ||
end | ||
|
||
function Base.getindex(block_arr::BlockSparseArray{T,N}, i::Vararg{Integer,N}) where {T,N} | ||
@boundscheck checkbounds(block_arr, i...) | ||
v = block_arr[findblockindex.(axes(block_arr), i)...] | ||
return v | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
struct SparseArray{T,N,Zero} <: AbstractArray{T,N} | ||
data::Dictionary{CartesianIndex{N},T} | ||
dims::NTuple{N,Int64} | ||
zero::Zero | ||
end | ||
|
||
Base.size(a::SparseArray) = a.dims | ||
|
||
function Base.setindex!(a::SparseArray{T,N}, v, I::CartesianIndex{N}) where {T,N} | ||
set!(a.data, I, v) | ||
return a | ||
end | ||
function Base.setindex!(a::SparseArray{T,N}, v, I::Vararg{Int,N}) where {T,N} | ||
return setindex!(a, v, CartesianIndex(I)) | ||
end | ||
|
||
function Base.getindex(a::SparseArray{T,N}, I::CartesianIndex{N}) where {T,N} | ||
return get(a.data, I, a.zero(T, I)) | ||
end | ||
function Base.getindex(a::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N} | ||
return getindex(a, CartesianIndex(I)) | ||
end | ||
|
||
## # `getindex` but uses a default if the value is | ||
## # structurally zero. | ||
## function get_nonzero(a::SparseArray{T,N}, I::CartesianIndex{N}, zero) where {T,N} | ||
## @boundscheck checkbounds(a, I) | ||
## return get(a.data, I, zero) | ||
## end | ||
## function get_nonzero(a::SparseArray{T,N}, I::NTuple{N,Int}, zero) where {T,N} | ||
## return get_nonzero(a, CartesianIndex(I), zero) | ||
## end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
using Test | ||
using NDTensors.BlockSparseArrays | ||
|
||
@testset "Test NDTensors.BlockSparseArrays" begin | ||
@testset "README" begin | ||
@test include( | ||
joinpath( | ||
pkgdir(BlockSparseArrays), "src", "BlockSparseArrays", "examples", "README.jl" | ||
), | ||
) isa Any | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# TODO: Implement. | ||
function contraction_output(tensor1::BlockSparseArray, tensor2::BlockSparseArray, indsR) | ||
return error("Not implemented") | ||
end | ||
|
||
# TODO: Implement. | ||
function contract!( | ||
tensorR::BlockSparseArray, | ||
labelsR, | ||
tensor1::BlockSparseArray, | ||
labels1, | ||
tensor2::BlockSparseArray, | ||
labels2, | ||
) | ||
return error("Not implemented") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
using Test | ||
using NDTensors | ||
|
||
include(joinpath(pkgdir(NDTensors), "src", "BlockSparseArrays", "test", "runtests.jl")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.