Skip to content

Commit

Permalink
Change behavior of non-blocked slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 4, 2024
1 parent c3eb3e7 commit d73ce01
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using ArrayLayouts: LayoutArray
using BlockArrays: blockisequal
using ..GradedAxes: blocked_getindex
using LinearAlgebra: Adjoint, Transpose
using ..SparseArrayInterface:
SparseArrayInterface,
Expand All @@ -26,23 +27,41 @@ end
# This is type piracy, try to avoid this, maybe requires defining `map`.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)

function SparseArrayInterface.sparse_map!(
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
# Work around issue that:
#
# julia> using BlockArrays: blocks
#
# julia> blocks(randn(2, 2))[1, 1]
# 2×2 view(::Matrix{Float64}, BlockSlice(Block(1),Base.OneTo(2)), BlockSlice(Block(1),Base.OneTo(2))) with eltype Float64:
# 0.0534014 -1.1738
# -0.649799 0.128661
#
function blocks_getindex(a::AbstractArray{<:Any,N}, index::Vararg{Integer,N}) where {N}
return a[index...]
end
function blocks_getindex(
a::BlocksView{<:Any,N,<:Any,<:Array{<:Any,N}}, index::Vararg{Integer,N}
) where {N}
return a.array
end

function blocked_map!(f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray})
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
BI_dest = blockindexrange(a_dest, I)
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Use `blocks_getindex`.
block_dest = blocks_getindex(blocks(a_dest), Int.(Tuple(_block(BI_dest)))...)
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
# TODO: Use `blocks_getindex`.
return blocks_getindex(blocks(a_srcs[i]), Int.(Tuple(_block(BI_srcs[i])))...)
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
# TODO: Use `map!!` to handle immutable blocks.
# TODO: Use `map!!` to handle immutable blocks, such as FillArrays.
map!(f, subblock_dest, subblock_srcs...)
# Replace the entire block, handles initializing new blocks
# or if blocks are immutable.
Expand All @@ -51,6 +70,20 @@ function SparseArrayInterface.sparse_map!(
return a_dest
end

# Convert a non-block SubArray to a blocked subarray
# using the blocking of the underlying array.
to_blocked(a::AbstractArray) = a
function to_blocked(a::SubArray)
# Returns a `BlockedSubArray`.
return blocked_view(parent(a), parentindices(a)...)
end

function SparseArrayInterface.sparse_map!(
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
return blocked_map!(f, to_blocked.((a_dest, a_srcs...))...)
end

# TODO: Implement this.
# function SparseArrayInterface.sparse_mapreduce(::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray})
# end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ using SplitApplyCombine: groupcount

using Adapt: Adapt, WrappedArray

const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}
const WrappedAbstractBlockSparseArray{T,N} = Union{
WrappedArray{T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}},
BlockedSubArray{T,N,<:AbstractBlockSparseArray{T,N}},
}

# TODO: Rename `AnyBlockSparseArray`.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BlockArrays:
AbstractBlockArray,
AbstractBlockVector,
Block,
BlockedUnitRange,
Expand Down Expand Up @@ -265,12 +266,13 @@ end

# Represents the array of arrays of a `SubArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
struct SparseSubArrayBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N}
# TODO: Define `blockstype`.
struct SparseSubArrayBlocks{T,N,Array<:AbstractArray{T,N}} <: AbstractSparseArray{T,N}
array::Array
end
# TODO: Define this as `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
function blockrange(a::SparseSubArrayBlocks)
blockranges = blockrange.(axes(parent(a.array)), a.array.indices)
blockranges = blockrange.(axes(parent(a.array)), parentindices(a.array))
return map(blockrange -> Int.(blockrange), blockranges)
end
function Base.axes(a::SparseSubArrayBlocks)
Expand All @@ -284,16 +286,17 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where
parent_block = parent_blocks[I...]
# TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a)))
return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...]
return @view parent_block[blockindices(parent(a.array), block, parentindices(a.array))...]
end
# TODO: This should be handled by generic `AbstractSparseArray` code.
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N}
return a[Tuple(I)...]
end
function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N}
parent_blocks = view(blocks(parent(a.array)), axes(a)...)
return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] =
value
return parent_blocks[I...][blockindices(
parent(a.array), Block(I), parentindices(a.array)
)...] = value
end
function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
if CartesianIndex(I) CartesianIndices(a)
Expand All @@ -313,10 +316,43 @@ function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks)
return error("Not implemented")
end

function blocksparse_blocks(a::SubArray)
# An alternative to `SubArray` where the blocking
# is determined from the parent.
# See https://github.com/JuliaArrays/BlockArrays.jl/issues/347.
struct BlockedSubArray{T,N,P,I} <: AbstractBlockArray{T,N}
parent::P
indices::I
function BlockedSubArray(parent, indices)
return new{eltype(parent),ndims(parent),typeof(parent),typeof(indices)}(parent, indices)
end
end
Base.parent(a::BlockedSubArray) = getfield(a, :parent)
Base.parentindices(a::BlockedSubArray) = getfield(a, :indices)
to_subarray(a::BlockedSubArray) = view(parent(a), parentindices(a)...)
function Base.axes(a::BlockedSubArray)
return ntuple(ndims(a)) do dim
return only(axes(blocked_getindex(axes(parent(a), dim), parentindices(a)[dim])))
end
end
Base.size(a::BlockedSubArray) = map(length, axes(a))
function Base.getindex(a::BlockedSubArray{<:Any,N}, I::Vararg{Int,N}) where {N}
return to_subarray(a)[I...]
end

function blocked_view(
a::AbstractArray{<:Any,N}, indices::Vararg{AbstractUnitRange,N}
) where {N}
return BlockedSubArray(a, indices)
end

function blocksparse_blocks(a::BlockedSubArray)
return SparseSubArrayBlocks(a)
end

function blocksparse_blocks(a::SubArray)
return BlocksView(a)
end

using BlockArrays: BlocksView
# TODO: Is this correct in general?
SparseArrayInterface.nstored(a::BlocksView) = 1
14 changes: 14 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ using BlockArrays: block, blockindex
function blockedunitrange_getindices(
a::BlockedUnitRange, indices::AbstractUnitRange{<:Integer}
)
return indices
end

# TODO: Move this to a `BlockArraysExtensions` library.
# Slice a BlockedUnitRange, preserving the blocking.
# See https://github.com/JuliaArrays/BlockArrays.jl/issues/347.
function blocked_getindex(a::AbstractUnitRange, indices)
return a[indices]
end

# TODO: Move this to a `BlockArraysExtensions` library.
# Slice a BlockedUnitRange, preserving the blocking.
# See https://github.com/JuliaArrays/BlockArrays.jl/issues/347.
function blocked_getindex(a::BlockedUnitRange, indices::AbstractUnitRange{<:Integer})
first_blockindex = blockedunitrange_findblockindex(a, first(indices))
last_blockindex = blockedunitrange_findblockindex(a, last(indices))
first_block = block(first_blockindex)
Expand Down

0 comments on commit d73ce01

Please sign in to comment.