Skip to content

Commit

Permalink
Start implementing block merging
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jul 1, 2024
1 parent 60cda06 commit 710fd89
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ using BlockArrays:
BlockIndex,
BlockIndexRange,
BlockRange,
BlockSlice,
BlockVector,
BlockedOneTo,
BlockedUnitRange,
BlockVector,
BlockSlice,
BlockedVector,
block,
blockaxes,
blockedrange,
Expand All @@ -22,6 +23,26 @@ using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices

# A return type for `blocks(array)` when `array` isn't blocked.
# Represents a vector with just that single block.
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
blocks_maybe_single(a) = blocks(a)
blocks_maybe_single(a::Array) = SingleBlockView(a)
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
@assert all(isone, index)
return a.array
end

# A wrapper around a potentially blocked array that is not blocked.
struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
Base.size(a::NonBlockedArray) = size(a.array)
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...]
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
Expand All @@ -39,6 +60,21 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end
function Base.getindex(
S::BlockIndices{<:AbstractBlockVector{<:Block{1}}}, i::BlockSlice{<:Block{1}}
)
# TODO: Check for conistency of indices.
# Calling `mortar` on the indices wraps the multiple blocks into a single
# block, since the result shouldn't be blocked.
## return BlockIndices(S.blocks[Block(i)], NonBlockedArray(S.indices[Block(i)]))
return NonBlockedArray(S.indices[Block(i)])
end
function Base.getindex(
S::BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}}, i::BlockSlice{<:Block{1}}
)
return i
end

function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# TODO: Turn this into a `blockedunitrange_getindices` definition.
Expand Down Expand Up @@ -94,6 +130,13 @@ function to_blockindices(a::BlockedOneTo{<:Integer}, I::UnitRange{<:Integer})
)
end

# This handles non-blocked slices.
# For example:
# a = BlockSparseArray{Float64}([2, 2, 2, 2])
# I = BlockedVector(Block.(1:4), [2, 2])
# @views a[I][Block(1)]
to_blockindices(a::Base.OneTo{<:Integer}, I::UnitRange{<:Integer}) = I

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
# ```julia
Expand Down Expand Up @@ -291,14 +334,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
return only(blockaxes(r))
end

# This handles changing the blocking, for example:
# This handles block merging:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector(Block.(1:4), [2, 2])
# I = BlockVector(Block.(1:4), [2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
for b in r
@assert b blockaxes(axis, 1)
end
return only(blockaxes(r))
end

Expand Down Expand Up @@ -337,6 +383,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
return only(blockaxes(axis))
end

function blockrange(axis::AbstractUnitRange, r::NonBlockedArray)
return Block(1):Block(1)
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@ end
# This is type piracy, try to avoid this, maybe requires defining `map`.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)

struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
_blocks(a) = blocks(a)
_blocks(a::Array) = SingleBlockView(a)
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
@assert all(isone, index)
return a.array
end

reblock(a) = a
# If the blocking of the slice doesn't match the blocking of the
# parent array, reblock according to the blocking of the parent array.
Expand All @@ -57,11 +47,11 @@ function SparseArrayInterface.sparse_map!(
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = _blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return _blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BlockArrays: BlockArrays, Block, BlockIndexRange, blocklength, blocksize, viewblock
using BlockArrays:
BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock

# This splits `BlockIndexRange{N}` into
# `NTuple{N,BlockIndexRange{1}}`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,13 @@ end
# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])]
# Permute and merge blocks.
# TODO: This isn't merging blocks yet, that needs to be implemented that.
function blocksparse_to_indices(a, inds, I::Tuple{BlockVector{<:Block{1}},Vararg{Any}})
function blocksparse_to_indices(
a, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}}
)
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

# TODO: Should this be combined with the version above?
function blocksparse_to_indices(a, inds, I::Tuple{BlockedVector{<:Block{1}},Vararg{Any}})
I1 = blockedunitrange_getindices(inds[1], I[1])
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

# TODO: Need to implement this!
function block_merge end

Expand Down
15 changes: 2 additions & 13 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,7 @@ include("TestBlockSparseArraysUtils.jl")
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
for I in (
Block.(1:2),
[Block(1), Block(2)],
BlockVector([Block(1), Block(2)], [1, 1]),
# TODO: This should merge blocks.
BlockVector([Block(1), Block(2)], [2]),
)
for I in (Block.(1:2), [Block(1), Block(2)])
b = @view a[I, I]
for I in CartesianIndices(a)
@test b[I] == a[I]
Expand All @@ -439,12 +433,7 @@ include("TestBlockSparseArraysUtils.jl")
# TODO: Use `blocksizes(a)[Int.(Tuple(b))...]` once available.
a[b] = randn(elt, size(a[b]))
end
for I in (
[Block(2), Block(1)],
BlockVector([Block(2), Block(1)], [1, 1]),
# TODO: This should merge blocks.
BlockVector([Block(2), Block(1)], [2]),
)
for I in ([Block(2), Block(1)],)
b = @view a[I, I]
@test b[Block(1, 1)] == a[Block(2, 2)]
@test b[Block(2, 1)] == a[Block(1, 2)]
Expand Down
17 changes: 8 additions & 9 deletions NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using BlockArrays:
BlockArrays,
AbstractBlockVector,
AbstractBlockedUnitRange,
Block,
BlockIndex,
Expand Down Expand Up @@ -74,16 +75,14 @@ function blockedunitrange_getindices(
end

# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly.
# TODO: Make a special case for `BlockedVector{<:Block{1},<:BlockRange{1}}`?
# For example:
# ```julia
# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices))
# return blockedrange(blocklengths)
# ```
function blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1},<:BlockRange{1}}
)
blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices))
return blockedrange(blocklengths)
end

# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly.
function blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1}}
a::AbstractBlockedUnitRange, indices::AbstractBlockVector{<:Block{1}}
)
return mortar(map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)))
end
Expand Down

0 comments on commit 710fd89

Please sign in to comment.