Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BlockSparseArrays] Permute and merge blocks #1514

Merged
merged 27 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e5c7fce
[BlockSparseArrays] Redesign block views again
mtfishman Jun 26, 2024
849051b
[NDTensors] Bump to v0.3.38
mtfishman Jun 26, 2024
4e7a280
[BlockSparseArrays] Merge blocks
mtfishman Jun 27, 2024
87b2ee1
Fix issues with GradedAxes
mtfishman Jun 27, 2024
31157d6
Merge branch 'BlockSparseArrays_redesign_blockviews_again' into Block…
mtfishman Jun 27, 2024
79bf2d4
Start to fix map
mtfishman Jun 27, 2024
1a22fd3
Initial version of block merging
mtfishman Jun 27, 2024
d322bf0
Merge branch 'main' into BlockSparseArrays_merge_blocks_2
mtfishman Jun 27, 2024
a3219fa
Fix some namespace issues
mtfishman Jun 27, 2024
e075825
Fix some tests
mtfishman Jun 28, 2024
e90dbec
Fix another test
mtfishman Jun 28, 2024
991fa07
Fix namespace issue
mtfishman Jun 28, 2024
ce44726
Output a more specific type when slicing with unit ranges
mtfishman Jun 28, 2024
17e274a
Fix some more tests
mtfishman Jun 30, 2024
29c5fbe
Fix some broken tests
mtfishman Jun 30, 2024
5bd85e3
Add tests
mtfishman Jun 30, 2024
60cda06
[NDTensors] Bump to v0.3.39
mtfishman Jun 30, 2024
710fd89
Start implementing block merging
mtfishman Jul 1, 2024
9dbac78
Slicing operation to permute and merge blocks
mtfishman Jul 1, 2024
b6ebf38
Preserve block labels during fusion
mtfishman Jul 1, 2024
88179e1
Fix some tests
mtfishman Jul 1, 2024
6d6b47f
Fix splitdims
mtfishman Jul 1, 2024
a1b0730
Fix for Julia 1.6
mtfishman Jul 1, 2024
7bfd24a
Fix typos
mtfishman Jul 1, 2024
da5abfa
Another fix for Julia 1.6
mtfishman Jul 1, 2024
a38ecdb
Another fix for Julia 1.6
mtfishman Jul 1, 2024
67bcd40
Another fix for Julia 1.6
mtfishman Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.38"
version = "0.3.39"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ function TensorAlgebra.splitdims(
return length(axis) ≤ length(axes(a, i))
end
blockperms = invblockperm.(blocksortperm.(axes_prod))
a_blockpermed = a[blockperms...]
# TODO: This is doing extra copies of the blocks,
# use `@view a[axes_prod...]` instead.
# That will require implementing some reindexing logic
# for this combination of slicing.
a_unblocked = a[axes_prod...]
a_blockpermed = a_unblocked[blockperms...]
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
m = fusedims(a, (1, 2), (3, 4))
# TODO: Once block merging is implemented, this should
# be the real test.
for ax in axes(m)
@test ax isa GradedOneTo
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
@test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)]
@test blocklabels(ax) == [U1(0), U1(1), U1(2)]
end
for I in CartesianIndices(m)
if I ∈ CartesianIndex.([(1, 1), (4, 4)])
Expand All @@ -105,10 +100,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
end
@test a[1, 1, 1, 1] == m[1, 1]
@test a[2, 2, 2, 2] == m[4, 4]
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocksize(m) == (3, 3)
@test blocksize(m) == (4, 4)
@test blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))
end
@testset "dual axes" begin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ using BlockArrays:
AbstractBlockArray,
AbstractBlockVector,
Block,
BlockIndex,
BlockIndexRange,
BlockRange,
BlockSlice,
BlockVector,
BlockedOneTo,
BlockedUnitRange,
BlockVector,
BlockSlice,
BlockedVector,
block,
blockaxes,
blockedrange,
Expand All @@ -17,8 +20,30 @@ using BlockArrays:
findblockindex
using Compat: allequal
using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices
using ..GradedAxes: blockedunitrange_getindices, to_blockindices
using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices

# A return type for `blocks(array)` when `array` isn't blocked.
# Represents a vector with just that single block.
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
blocks_maybe_single(a) = blocks(a)
blocks_maybe_single(a::Array) = SingleBlockView(a)
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
@assert all(isone, index)
return a.array
end

# A wrapper around a potentially blocked array that is not blocked.
struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
Base.size(a::NonBlockedArray) = size(a.array)
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...]
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array)
const NonBlockedVector{T,Array} = NonBlockedArray{T,1,Array}
NonBlockedVector(array::AbstractVector) = NonBlockedArray(array)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
Expand All @@ -37,6 +62,43 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end

# This is used in slicing like:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
function Base.getindex(
S::BlockIndices{<:AbstractBlockVector{<:Block{1}}}, i::BlockSlice{<:Block{1}}
)
# TODO: Check for conistency of indices.
# Wrapping the indices in `NonBlockedVector` reinterprets the blocked indices
# as a single block, since the result shouldn't be blocked.
return NonBlockedVector(BlockIndices(S.blocks[Block(i)], S.indices[Block(i)]))
end
function Base.getindex(
S::BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}}, i::BlockSlice{<:Block{1}}
)
return i
end

# Used in indexing such as:
# ```julia
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# b = @view a[I, I]
# @view b[Block(1, 1)[1:2, 2:2]]
# ```
# This is similar to the definition:
# blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})
function Base.getindex(
a::NonBlockedVector{<:Integer,<:BlockIndices}, I::UnitRange{<:Integer}
)
ax = only(axes(a.array.indices))
brs = to_blockindices(ax, I)
inds = blockedunitrange_getindices(ax, I)
return NonBlockedVector(a.array[BlockSlice(brs, inds)])
end

function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# TODO: Turn this into a `blockedunitrange_getindices` definition.
Expand All @@ -50,6 +112,34 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
return BlockIndices(subblocks, subindices)
end

# Used when performing slices like:
# @views a[[Block(2), Block(1)]][2:4, 2:4]
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockVector{<:BlockIndex{1}}})
subblocks = mortar(
map(blocks(i.block)) do br
return S.blocks[Int(Block(br))][only(br.indices)]
end,
)
subindices = mortar(
map(blocks(i.block)) do br
S.indices[br]
end,
)
return BlockIndices(subblocks, subindices)
end

# Similar to the definition of `BlockArrays.BlockSlices`:
# ```julia
# const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}}
# ```
# but includes `BlockIndices`, where the blocks aren't contiguous.
const BlockSliceCollection = Union{
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
}
const SubBlockSliceCollection = BlockIndices{
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
}

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
# ```julia
Expand Down Expand Up @@ -218,6 +308,12 @@ function blockrange(axis::AbstractUnitRange, r::UnitRange)
return findblock(axis, first(r)):findblock(axis, last(r))
end

# Occurs when slicing with `a[2:4, 2:4]`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedUnitRange{<:Integer})
# TODO: Check the blocks are commensurate.
return findblock(axis, first(r)):findblock(axis, last(r))
end

function blockrange(axis::AbstractUnitRange, r::Int)
## return findblock(axis, r)
return error("Slicing with integer values isn't supported.")
Expand All @@ -241,14 +337,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
return only(blockaxes(r))
end

# This handles changing the blocking, for example:
# This handles block merging:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector(Block.(1:4), [2, 2])
# I = BlockVector(Block.(1:4), [2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
for b in r
@assert b ∈ blockaxes(axis, 1)
end
return only(blockaxes(r))
end

Expand Down Expand Up @@ -287,6 +386,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
return only(blockaxes(axis))
end

function blockrange(axis::AbstractUnitRange, r::NonBlockedVector)
return Block(1):Block(1)
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end
Expand Down Expand Up @@ -423,7 +526,18 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher
return a
end

function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
function SparseArrayInterface.nstored(a::BlockView)
# TODO: Store whether or not the block is stored already as
# a Bool in `BlockView`.
I = CartesianIndex(Int.(a.block))
# TODO: Use `block_stored_indices`.
if I ∈ stored_indices(blocks(a.array))
return nstored(blocks(a.array)[I])
end
return 0
end

function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
return view!(a, Tuple(index)...)
end
function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N}
Expand Down
3 changes: 2 additions & 1 deletion NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module BlockSparseArrays
include("BlockArraysExtensions/BlockArraysExtensions.jl")
include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
include("blocksparsearrayinterface/linearalgebra.jl")
include("blocksparsearrayinterface/blockzero.jl")
include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
Expand All @@ -15,7 +17,6 @@ include("abstractblocksparsearray/broadcast.jl")
include("abstractblocksparsearray/map.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysExtensions/BlockArraysExtensions.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,57 @@ end
# This is type piracy, try to avoid this, maybe requires defining `map`.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)

reblock(a) = a

# If the blocking of the slice doesn't match the blocking of the
# parent array, reblock according to the blocking of the parent array.
function reblock(
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}}
)
# TODO: This relies on the behavior that slicing a block sparse
# array with a UnitRange inherits the blocking of the underlying
# block sparse array, we might change that default behavior
# so this might become something like `@blocked parent(a)[...]`.
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
end

function reblock(
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
)
return @view parent(a)[map(I -> I.array, parentindices(a))...]
end

function reblock(
a::SubArray{
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}},
},
)
# Remove the blocking.
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
end

# TODO: Rewrite this so that it takes the blocking structure
# made by combining the blocking of the axes (i.e. the blocking that
# is used to determine `union_stored_blocked_cartesianindices(...)`).
# `reblock` is a partial solution to that, but a bit ad-hoc.
# TODO: Move to `blocksparsearrayinterface/map.jl`.
function SparseArrayInterface.sparse_map!(
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
BI_dest = blockindexrange(a_dest, I)
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
Expand Down
Loading
Loading