Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BlockSparseArrays] Redesign nested views #1504

Merged
merged 27 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4a22837
[WIP] [BlockSparseArrays] Redesign nested views
mtfishman Jun 19, 2024
116bce1
Merge branch 'main' into BlockSparseArray_redesign_nested_slicing
mtfishman Jun 20, 2024
cb319e6
Fix some tests
mtfishman Jun 21, 2024
d127189
Merge branch 'main' into BlockSparseArray_redesign_nested_slicing
mtfishman Jun 21, 2024
78d776b
Start rewriting view logic
mtfishman Jun 21, 2024
93ca493
Upgrade to new BlockSlice type in BlockArrays v1.1
mtfishman Jun 21, 2024
951f495
Start fixing slicing with unit ranges
mtfishman Jun 21, 2024
895dc80
Fix more tests
mtfishman Jun 21, 2024
7d7030b
Fix namespace issues
mtfishman Jun 21, 2024
0bd7598
Fix namespace issue
mtfishman Jun 21, 2024
3a7d602
Reorganize some tests
mtfishman Jun 22, 2024
fe32cd6
Some cleanup
mtfishman Jun 22, 2024
50b4511
Fix tests
mtfishman Jun 22, 2024
67afbc5
Start fixing unit range slicing of BlockIndices
mtfishman Jun 23, 2024
3817a6c
Merge branch 'main' into BlockSparseArray_redesign_nested_slicing
mtfishman Jun 23, 2024
01337c0
Merge branch 'main' into BlockSparseArray_redesign_nested_slicing
mtfishman Jun 24, 2024
26e2070
Slicing arrays with permuted blocks with unit ranges
mtfishman Jun 24, 2024
92d9d50
Fix some tests
mtfishman Jun 24, 2024
88e908b
Fix some tests, break some tests, the eternal cycle of life
mtfishman Jun 24, 2024
a44b776
Fixing tests
mtfishman Jun 24, 2024
9f94118
Fix more tests
mtfishman Jun 24, 2024
9de01eb
More broken tests
mtfishman Jun 24, 2024
f054f06
Fix more tests
mtfishman Jun 25, 2024
8b4bb7a
Unbrake some tests
mtfishman Jun 25, 2024
4873cab
Fix some more tests
mtfishman Jun 25, 2024
494345f
Fix tests
mtfishman Jun 25, 2024
28bd82a
[NDTensors] Bump to v0.3.36
mtfishman Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.35"
version = "0.3.36"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,6 @@ function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
return dual.(reverse(axes(a')))
end

# TODO: Delete this definition in favor of the one in
# GradedAxes once https://github.com/JuliaArrays/BlockArrays.jl/pull/405 is merged.
# TODO: Make a special definition for `BlockedVector{<:Block{1}}` in order
# to merge blocks.
function GradedAxes.blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
)
# Without converting `indices` to `Vector`,
# mapping `indices` outputs a `BlockVector`
# which is harder to reason about.
blocks = map(index -> a[index], Vector(indices))
# We pass `length.(blocks)` to `mortar` in order
# to pass block labels to the axes of the output,
# if they exist. This makes it so that
# `only(axes(a[indices])) isa `GradedUnitRange`
# if `a isa `GradedUnitRange`, for example.
# TODO: Remove `unlabel` once `BlockArray` axes
# type is generalized in BlockArrays.jl.
# TODO: Support using `BlockSparseVector`, need
# to make more `BlockSparseArray` constructors.
return BlockSparseArray(blocks, (blockedrange(length.(blocks)),))
end

# This definition is only needed since calls like
# `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
# returns a `BlockSparseVector` instead of a `BlockVector`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
# TODO: Fix this for `BlockedArray`.
@test_broken nstored(b) == 256
@test nstored(b) == 256
# TODO: Fix this for `BlockedArray`.
@test_broken block_nstored(b) == 16
for i in 1:ndims(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using BlockArrays:
AbstractBlockVector,
Block,
BlockRange,
BlockedOneTo,
BlockedUnitRange,
BlockVector,
BlockSlice,
Expand All @@ -19,19 +20,6 @@ using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
struct GenericBlockSlice{B,T<:Integer,I<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
block::B
indices::I
end
BlockArrays.Block(bs::GenericBlockSlice{<:Block}) = bs.block
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
@eval Base.$f(S::GenericBlockSlice) = Base.$f(S.indices)
end
Base.getindex(S::GenericBlockSlice, i::Integer) = getindex(S.indices, i)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
Expand All @@ -42,6 +30,63 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
@eval Base.$f(S::BlockIndices) = Base.$f(S.indices)
end
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# It seems like this isn't handling the case where `i` is a
# subslice of a block correctly (i.e. it ignores `i.indices`).
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# TODO: Turn this into a `blockedunitrange_getindices` definition.
subblocks = S.blocks[Int.(i.block)]
subindices = mortar(
map(1:length(i.block)) do I
r = blocks(i.indices)[I]
return S.indices[first(r)]:S.indices[last(r)]
end,
)
return BlockIndices(subblocks, subindices)
end

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
# ```julia
# a = BlockSparseArray{elt}(undef, ([2, 3], [2, 3]))
# b = @view a[[Block(1)[1:1], Block(2)[1:2]], [Block(1)[1:1], Block(2)[1:2]]]
# b[Block(1, 1)]
# ```
# Without this change, BlockArrays has the slicing behavior:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[BlockSlice(Block(2), 2:3)]
# 2-element Vector{BlockIndex{1, Tuple{Int64}, Tuple{Int64}}}:
# Block(2)[1]
# Block(2)[2]
# ```
# while with this change it has the slicing behavior:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[BlockSlice(Block(2), 2:3)]
# Block(2)[1:2]
# ```
# i.e. it preserves the types of the blocks better. Upstream this fix to
# BlockArrays.jl. Also consider overloading `reindex` so that it calls
# a custom `getindex` function to avoid type piracy in the meantime.
# Also fix this in BlockArrays:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[Block(2)]
# 2-element Vector{BlockIndex{1, Tuple{Int64}, Tuple{Int64}}}:
# Block(2)[1]
# Block(2)[2]
# ```
function Base.getindex(
a::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
I::BlockSlice{<:Block{1}},
)
# Check that the block slice corresponds to the correct block.
@assert I.indices == only(axes(a))[Block(I)]
return blocks(a)[Int(Block(I))]
end

# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices)
Expand Down Expand Up @@ -185,15 +230,12 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
return r
end

using BlockArrays: BlockSlice
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
function blockrange(axis::AbstractUnitRange, r::GenericBlockSlice)
using BlockArrays: BlockSlice
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
include("abstractblocksparsearray/abstractblocksparsevector.jl")
include("abstractblocksparsearray/view.jl")
include("abstractblocksparsearray/views.jl")
include("abstractblocksparsearray/arraylayouts.jl")
include("abstractblocksparsearray/sparsearrayinterface.jl")
include("abstractblocksparsearray/broadcast.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
<:Tuple{BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)
Expand All @@ -25,8 +25,8 @@ function Broadcast.BroadcastStyle(
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},
Vararg{Any},
},
},
Expand All @@ -40,7 +40,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Any,BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
<:Tuple{Any,BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ end

# TODO: Why isn't this calling `mapreduce` already?
function Base.iszero(a::BlockSparseArrayLike)
return sparse_iszero(a)
return sparse_iszero(blocks(a))
end

# TODO: Why isn't this calling `mapreduce` already?
function Base.isreal(a::BlockSparseArrayLike)
return sparse_isreal(a)
return sparse_isreal(blocks(a))
end
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ end
function SparseArrayInterface.sparse_storage(a::AbstractBlockSparseArray)
return BlockSparseStorage(a)
end

function SparseArrayInterface.nstored(a::BlockSparseArrayLike)
return sum(nstored, sparse_storage(blocks(a)); init=zero(Int))
end

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using BlockArrays: Block, BlockSlices

function blocksparse_view(a, I...)
return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...)
end

# These definitions circumvent some generic definitions in BlockArrays.jl:
# https://github.com/JuliaArrays/BlockArrays.jl/blob/master/src/views.jl
# which don't handle subslices of blocks properly.
function Base.view(
a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Block{N}
) where {N}
return blocksparse_view(a, I)
end
function Base.view(
a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Vararg{Block{1},N}
) where {N}
return blocksparse_view(a, I...)
end
function Base.view(
V::SubArray{<:Any,1,<:BlockSparseArrayLike,<:Tuple{BlockSlices}}, I::Block{1}
)
return blocksparse_view(a, I)
end
Loading
Loading