From d734e640a385ffa9157d5edd4786aea98033fc0b Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Mon, 1 Jul 2024 18:21:57 -0400 Subject: [PATCH] [BlockSparseArrays] Permute and merge blocks (#1514) * [BlockSparseArrays] Permute and merge blocks * [NDTensors] Bump to v0.3.39 --- NDTensors/Project.toml | 2 +- .../src/BlockSparseArraysGradedAxesExt.jl | 7 +- .../test/runtests.jl | 12 +- .../BlockArraysExtensions.jl | 134 +++++++++- .../src/BlockSparseArrays.jl | 3 +- .../src/abstractblocksparsearray/map.jl | 42 ++- .../src/abstractblocksparsearray/views.jl | 241 +++++++++++++++++- .../wrappedabstractblocksparsearray.jl | 21 +- .../blocksparsearrayinterface.jl | 54 ++-- .../src/blocksparsearrayinterface/views.jl | 3 + .../lib/BlockSparseArrays/test/test_basics.jl | 119 ++++++--- .../lib/GradedAxes/src/blockedunitrange.jl | 64 ++++- .../src/lib/GradedAxes/src/gradedunitrange.jl | 54 +++- .../src/lib/GradedAxes/src/unitrangedual.jl | 13 + 14 files changed, 673 insertions(+), 96 deletions(-) create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/views.jl diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index b831edc8a9..66be11c53d 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.38" +version = "0.3.39" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl index adf7b35de0..3a815d4ea2 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl @@ -66,7 +66,12 @@ function TensorAlgebra.splitdims( return length(axis) ≤ length(axes(a, i)) end blockperms = invblockperm.(blocksortperm.(axes_prod)) - a_blockpermed = a[blockperms...] + # TODO: This is doing extra copies of the blocks, + # use `@view a[axes_prod...]` instead. + # That will require implementing some reindexing logic + # for this combination of slicing. + a_unblocked = a[axes_prod...] + a_blockpermed = a_unblocked[blockperms...] return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...) end diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index a7a54fb9cb..38142b65f5 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -87,14 +87,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) m = fusedims(a, (1, 2), (3, 4)) - # TODO: Once block merging is implemented, this should - # be the real test. for ax in axes(m) @test ax isa GradedOneTo - # TODO: Current `fusedims` doesn't merge - # common sectors, need to fix. - @test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)] - @test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)] + @test blocklabels(ax) == [U1(0), U1(1), U1(2)] end for I in CartesianIndices(m) if I ∈ CartesianIndex.([(1, 1), (4, 4)]) @@ -105,10 +100,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end @test a[1, 1, 1, 1] == m[1, 1] @test a[2, 2, 2, 2] == m[4, 4] - # TODO: Current `fusedims` doesn't merge - # common sectors, need to fix. - @test_broken blocksize(m) == (3, 3) - @test blocksize(m) == (4, 4) + @test blocksize(m) == (3, 3) @test a == splitdims(m, (d1, d2), (d1, d2)) end @testset "dual axes" begin diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 499fd42089..7e7b503475 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -3,11 +3,14 @@ using BlockArrays: AbstractBlockArray, AbstractBlockVector, Block, + BlockIndex, + BlockIndexRange, BlockRange, + BlockSlice, + BlockVector, BlockedOneTo, BlockedUnitRange, - BlockVector, - BlockSlice, + BlockedVector, block, blockaxes, blockedrange, @@ -17,8 +20,30 @@ using BlockArrays: findblockindex using Compat: allequal using Dictionaries: Dictionary, Indices -using ..GradedAxes: blockedunitrange_getindices -using ..SparseArrayInterface: stored_indices +using ..GradedAxes: blockedunitrange_getindices, to_blockindices +using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices + +# A return type for `blocks(array)` when `array` isn't blocked. +# Represents a vector with just that single block. +struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N} + array::Array +end +blocks_maybe_single(a) = blocks(a) +blocks_maybe_single(a::Array) = SingleBlockView(a) +function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N} + @assert all(isone, index) + return a.array +end + +# A wrapper around a potentially blocked array that is not blocked. +struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N} + array::Array +end +Base.size(a::NonBlockedArray) = size(a.array) +Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...] +BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array) +const NonBlockedVector{T,Array} = NonBlockedArray{T,1,Array} +NonBlockedVector(array::AbstractVector) = NonBlockedArray(array) # BlockIndices works around an issue that the indices of BlockSlice # are restricted to AbstractUnitRange{Int}. @@ -37,6 +62,43 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}}) @assert length(S.indices[Block(i)]) == length(i.indices) return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)]) end + +# This is used in slicing like: +# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2]) +# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) +# a[I, I] +function Base.getindex( + S::BlockIndices{<:AbstractBlockVector{<:Block{1}}}, i::BlockSlice{<:Block{1}} +) + # TODO: Check for conistency of indices. + # Wrapping the indices in `NonBlockedVector` reinterprets the blocked indices + # as a single block, since the result shouldn't be blocked. + return NonBlockedVector(BlockIndices(S.blocks[Block(i)], S.indices[Block(i)])) +end +function Base.getindex( + S::BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}}, i::BlockSlice{<:Block{1}} +) + return i +end + +# Used in indexing such as: +# ```julia +# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2]) +# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) +# b = @view a[I, I] +# @view b[Block(1, 1)[1:2, 2:2]] +# ``` +# This is similar to the definition: +# blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}}) +function Base.getindex( + a::NonBlockedVector{<:Integer,<:BlockIndices}, I::UnitRange{<:Integer} +) + ax = only(axes(a.array.indices)) + brs = to_blockindices(ax, I) + inds = blockedunitrange_getindices(ax, I) + return NonBlockedVector(a.array[BlockSlice(brs, inds)]) +end + function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}}) # TODO: Check that `i.indices` is consistent with `S.indices`. # TODO: Turn this into a `blockedunitrange_getindices` definition. @@ -50,6 +112,34 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}}) return BlockIndices(subblocks, subindices) end +# Used when performing slices like: +# @views a[[Block(2), Block(1)]][2:4, 2:4] +function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockVector{<:BlockIndex{1}}}) + subblocks = mortar( + map(blocks(i.block)) do br + return S.blocks[Int(Block(br))][only(br.indices)] + end, + ) + subindices = mortar( + map(blocks(i.block)) do br + S.indices[br] + end, + ) + return BlockIndices(subblocks, subindices) +end + +# Similar to the definition of `BlockArrays.BlockSlices`: +# ```julia +# const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}} +# ``` +# but includes `BlockIndices`, where the blocks aren't contiguous. +const BlockSliceCollection = Union{ + Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}} +} +const SubBlockSliceCollection = BlockIndices{ + <:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}} +} + # TODO: This is type piracy. This is used in `reindex` when making # views of blocks of sliced block arrays, for example: # ```julia @@ -218,6 +308,12 @@ function blockrange(axis::AbstractUnitRange, r::UnitRange) return findblock(axis, first(r)):findblock(axis, last(r)) end +# Occurs when slicing with `a[2:4, 2:4]`. +function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedUnitRange{<:Integer}) + # TODO: Check the blocks are commensurate. + return findblock(axis, first(r)):findblock(axis, last(r)) +end + function blockrange(axis::AbstractUnitRange, r::Int) ## return findblock(axis, r) return error("Slicing with integer values isn't supported.") @@ -241,14 +337,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer}) return only(blockaxes(r)) end -# This handles changing the blocking, for example: +# This handles block merging: # a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2]) +# I = BlockedVector(Block.(1:4), [2, 2]) +# I = BlockVector(Block.(1:4), [2, 2]) # I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) +# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) # a[I, I] -# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`. -function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer}) - # TODO: Probably this is incorrect and should be something like: - # return findblock(axis, first(r)):findblock(axis, last(r)) +function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}}) + for b in r + @assert b ∈ blockaxes(axis, 1) + end return only(blockaxes(r)) end @@ -287,6 +386,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice) return only(blockaxes(axis)) end +function blockrange(axis::AbstractUnitRange, r::NonBlockedVector) + return Block(1):Block(1) +end + function blockrange(axis::AbstractUnitRange, r) return error("Slicing not implemented for range of type `$(typeof(r))`.") end @@ -423,7 +526,18 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher return a end -function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N} +function SparseArrayInterface.nstored(a::BlockView) + # TODO: Store whether or not the block is stored already as + # a Bool in `BlockView`. + I = CartesianIndex(Int.(a.block)) + # TODO: Use `block_stored_indices`. + if I ∈ stored_indices(blocks(a.array)) + return nstored(blocks(a.array)[I]) + end + return 0 +end + +function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N} return view!(a, Tuple(index)...) end function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N} diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index a31d0269d5..3542c3e10b 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -1,9 +1,11 @@ module BlockSparseArrays +include("BlockArraysExtensions/BlockArraysExtensions.jl") include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") include("blocksparsearrayinterface/broadcast.jl") include("blocksparsearrayinterface/arraylayouts.jl") +include("blocksparsearrayinterface/views.jl") include("abstractblocksparsearray/abstractblocksparsearray.jl") include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") @@ -15,7 +17,6 @@ include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") -include("BlockArraysExtensions/BlockArraysExtensions.jl") include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl") include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl") include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index c376ff631a..b9ab510566 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -25,19 +25,57 @@ end # This is type piracy, try to avoid this, maybe requires defining `map`. ## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2) +reblock(a) = a + +# If the blocking of the slice doesn't match the blocking of the +# parent array, reblock according to the blocking of the parent array. +function reblock( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}} +) + # TODO: This relies on the behavior that slicing a block sparse + # array with a UnitRange inherits the blocking of the underlying + # block sparse array, we might change that default behavior + # so this might become something like `@blocked parent(a)[...]`. + return @view parent(a)[UnitRange{Int}.(parentindices(a))...] +end + +function reblock( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}} +) + return @view parent(a)[map(I -> I.array, parentindices(a))...] +end + +function reblock( + a::SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}}, + }, +) + # Remove the blocking. + return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] +end + +# TODO: Rewrite this so that it takes the blocking structure +# made by combining the blocking of the axes (i.e. the blocking that +# is used to determine `union_stored_blocked_cartesianindices(...)`). +# `reblock` is a partial solution to that, but a bit ad-hoc. +# TODO: Move to `blocksparsearrayinterface/map.jl`. function SparseArrayInterface.sparse_map!( ::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray} ) + a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) BI_dest = blockindexrange(a_dest, I) BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) # TODO: Investigate why this doesn't work: # block_dest = @view a_dest[_block(BI_dest)] - block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] + block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...] # TODO: Investigate why this doesn't work: # block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) block_srcs = ntuple(length(a_srcs)) do i - return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] + return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] end subblock_dest = @view block_dest[BI_dest.indices...] subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl index 456bb81827..e409ed5500 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl @@ -1,24 +1,48 @@ -using BlockArrays: BlockArrays, Block, BlockSlices, viewblock +using BlockArrays: + BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock -function blocksparse_view(a, I...) - return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...) +# This splits `BlockIndexRange{N}` into +# `NTuple{N,BlockIndexRange{1}}`. +# TODO: Move to `BlockArraysExtensions`. +to_tuple(x) = Tuple(x) +function to_tuple(x::BlockIndexRange{N}) where {N} + blocks = Tuple(Block(x)) + n = length(blocks) + return ntuple(dim -> blocks[dim][x.indices[dim]], n) +end + +# Override the default definition of `BlockArrays.blocksize`, +# which is incorrect for certain slices. +function BlockArrays.blocksize(a::SubArray{<:Any,<:Any,<:BlockSparseArrayLike}) + return blocklength.(axes(a)) +end +function BlockArrays.blocksize(a::SubArray{<:Any,<:Any,<:BlockSparseArrayLike}, i::Int) + # TODO: Maybe use `blocklength(axes(a, i))` which would be a bit faster. + return blocksize(a)[i] end # These definitions circumvent some generic definitions in BlockArrays.jl: # https://github.com/JuliaArrays/BlockArrays.jl/blob/master/src/views.jl # which don't handle subslices of blocks properly. function Base.view( - a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Block{N} + a::SubArray{ + <:Any,N,<:BlockSparseArrayLike,<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} + }, + I::Block{N}, ) where {N} return blocksparse_view(a, I) end function Base.view( - a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Vararg{Block{1},N} + a::SubArray{ + <:Any,N,<:BlockSparseArrayLike,<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} + }, + I::Vararg{Block{1},N}, ) where {N} return blocksparse_view(a, I...) end function Base.view( - V::SubArray{<:Any,1,<:BlockSparseArrayLike,<:Tuple{BlockSlices}}, I::Block{1} + V::SubArray{<:Any,1,<:BlockSparseArrayLike,<:Tuple{BlockSlice{<:BlockRange{1}}}}, + I::Block{1}, ) return blocksparse_view(a, I) end @@ -29,12 +53,217 @@ function BlockArrays.viewblock( ) where {N} return viewblock(a, Tuple(block)...) end + +# TODO: Define `blocksparse_viewblock`. function BlockArrays.viewblock( a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N} ) where {N} I = CartesianIndex(Int.(block)) + # TODO: Use `block_stored_indices`. if I ∈ stored_indices(blocks(a)) return blocks(a)[I] end return BlockView(a, block) end + +# Specialized code for getting the view of a subblock. +function Base.view( + a::AbstractBlockSparseArray{<:Any,N}, block::BlockIndexRange{N} +) where {N} + return view(a, to_tuple(block)...) +end + +# Specialized code for getting the view of a subblock. +function Base.view( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}}, I::BlockIndexRange{N} +) where {T,N} + return view(a, to_tuple(I)...) +end +function Base.view(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Block{1},N}) where {N} + return viewblock(a, I...) +end + +# TODO: Move to `GradedAxes` or `BlockArraysExtensions`. +to_block(I::Block{1}) = I +to_block(I::BlockIndexRange{1}) = Block(I) +to_block_indices(I::Block{1}) = Colon() +to_block_indices(I::BlockIndexRange{1}) = only(I.indices) + +function Base.view( + a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Union{Block{1},BlockIndexRange{1}},N} +) where {N} + return @views a[to_block.(I)...][to_block_indices.(I)...] +end + +function Base.view( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}}, I::Vararg{Block{1},N} +) where {T,N} + return viewblock(a, I...) +end +function Base.view( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}}, + I::Vararg{Union{Block{1},BlockIndexRange{1}},N}, +) where {T,N} + return @views a[to_block.(I)...][to_block_indices.(I)...] +end +# Generic fallback. +function BlockArrays.viewblock( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}}, I::Vararg{Block{1},N} +) where {T,N} + return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...) +end + +function Base.view( + a::SubArray{ + T, + N, + <:AbstractBlockSparseArray{T,N}, + <:Tuple{Vararg{Union{BlockSliceCollection,SubBlockSliceCollection},N}}, + }, + block::Union{Block{N},BlockIndexRange{N}}, +) where {T,N} + return viewblock(a, block) +end +function Base.view( + a::SubArray{ + T, + N, + <:AbstractBlockSparseArray{T,N}, + <:Tuple{Vararg{Union{BlockSliceCollection,SubBlockSliceCollection},N}}, + }, + block::Vararg{Union{Block{1},BlockIndexRange{1}},N}, +) where {T,N} + return viewblock(a, block...) +end +function BlockArrays.viewblock( + a::SubArray{ + T, + N, + <:AbstractBlockSparseArray{T,N}, + <:Tuple{Vararg{Union{BlockSliceCollection,SubBlockSliceCollection},N}}, + }, + block::Union{Block{N},BlockIndexRange{N}}, +) where {T,N} + return viewblock(a, to_tuple(block)...) +end + +# Fixes ambiguity error with `BlockSparseArrayLike` definition. +function Base.view( + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} + }, + block::Block{N}, +) where {T,N} + return viewblock(a, block) +end +# Fixes ambiguity error with `BlockSparseArrayLike` definition. +function Base.view( + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} + }, + block::Vararg{Block{1},N}, +) where {T,N} + return viewblock(a, block...) +end + +# XXX: TODO: Distinguish if a sub-view of the block needs to be taken! +# Define a new `SubBlockSlice` which is used in: +# `blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})` +# in `blocksparsearrayinterface/blocksparsearrayinterface.jl`. +# TODO: Define `blocksparse_viewblock`. +function BlockArrays.viewblock( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}}, + block::Vararg{Block{1},N}, +) where {T,N} + I = CartesianIndex(Int.(block)) + # TODO: Use `block_stored_indices`. + if I ∈ stored_indices(blocks(a)) + return blocks(a)[I] + end + return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I)))) +end + +function to_blockindexrange( + a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}}, + I::Block{1}, +) + # TODO: Ideally we would just use `a.blocks[I]` but that doesn't + # work right now. + return blocks(a.blocks)[Int(I)] +end +function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1}) + @assert I in only(blockaxes(a.indices)) + return I +end + +function BlockArrays.viewblock( + a::SubArray{ + T, + N, + <:AbstractBlockSparseArray{T,N}, + <:Tuple{Vararg{Union{BlockSliceCollection,SubBlockSliceCollection},N}}, + }, + block::Vararg{Block{1},N}, +) where {T,N} + brs = ntuple(dim -> to_blockindexrange(parentindices(a)[dim], block[dim]), ndims(a)) + return @view parent(a)[brs...] +end + +# TODO: Define `blocksparse_viewblock`. +function BlockArrays.viewblock( + a::SubArray{ + T, + N, + <:AbstractBlockSparseArray{T,N}, + <:Tuple{Vararg{Union{BlockSliceCollection,SubBlockSliceCollection},N}}, + }, + block::Vararg{BlockIndexRange{1},N}, +) where {T,N} + return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...) +end + +# Block slice of the result of slicing `@view a[2:5, 2:5]`. +# TODO: Move this to `BlockArraysExtensions`. +const BlockedSlice = BlockSlice{ + <:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}} +} + +function Base.view( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + block::Union{Block{N},BlockIndexRange{N}}, +) where {T,N} + return viewblock(a, block) +end +function Base.view( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + block::Vararg{Union{Block{1},BlockIndexRange{1}},N}, +) where {T,N} + return viewblock(a, block...) +end +function BlockArrays.viewblock( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + block::Union{Block{N},BlockIndexRange{N}}, +) where {T,N} + return viewblock(a, to_tuple(block)...) +end +# TODO: Define `blocksparse_viewblock`. +function BlockArrays.viewblock( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + I::Vararg{Block{1},N}, +) where {T,N} + # TODO: Use `reindex`, `to_indices`, etc. + brs = ntuple(ndims(a)) do dim + # TODO: Ideally we would use this but it outputs a Vector, + # not a range: + # return parentindices(a)[dim].block[I[dim]] + return blocks(parentindices(a)[dim].block)[Int(I[dim])] + end + return @view parent(a)[brs...] +end +# TODO: Define `blocksparse_viewblock`. +function BlockArrays.viewblock( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + block::Vararg{BlockIndexRange{1},N}, +) where {T,N} + return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index d25affba75..076c3a3f6a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -33,21 +33,30 @@ function Base.to_indices( return blocksparse_to_indices(a, inds, I) end -# a[[Block(1)[1:2], Block(2)[1:2]], [Block(1)[1:2], Block(2)[1:2]]] +# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])] +# a[BlockedVector([Block(2), Block(1)], [2]), BlockedVector([Block(2), Block(1)], [2])] function Base.to_indices( - a::BlockSparseArrayLike, inds, I::Tuple{Vector{<:BlockIndexRange{1}},Vararg{Any}} + a::BlockSparseArrayLike, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}} ) - return to_indices(a, inds, (mortar(I[1]), Base.tail(I)...)) + return blocksparse_to_indices(a, inds, I) end -# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])] -# a[BlockedVector([Block(2), Block(1)], [2]), BlockedVector([Block(2), Block(1)], [2])] +# a[mortar([Block(1)[1:2], Block(2)[1:3]])] function Base.to_indices( - a::BlockSparseArrayLike, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}} + a::BlockSparseArrayLike, + inds, + I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}}, ) return blocksparse_to_indices(a, inds, I) end +# a[[Block(1)[1:2], Block(2)[1:2]], [Block(1)[1:2], Block(2)[1:2]]] +function Base.to_indices( + a::BlockSparseArrayLike, inds, I::Tuple{Vector{<:BlockIndexRange{1}},Vararg{Any}} +) + return to_indices(a, inds, (mortar(I[1]), Base.tail(I)...)) +end + # BlockArrays `AbstractBlockArray` interface BlockArrays.blocks(a::BlockSparseArrayLike) = blocksparse_blocks(a) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index ee6790914b..182504e038 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -2,6 +2,8 @@ using BlockArrays: AbstractBlockVector, Block, BlockIndex, + BlockRange, + BlockSlice, BlockVector, BlockedUnitRange, BlockedVector, @@ -12,7 +14,6 @@ using BlockArrays: findblockindex using LinearAlgebra: Adjoint, Transpose using ..SparseArrayInterface: perm, iperm, nstored, sparse_zero! -## using MappedArrays: mappedarray blocksparse_blocks(a::AbstractArray) = error("Not implemented") @@ -28,28 +29,42 @@ end # https://github.com/JuliaArrays/BlockArrays.jl/issues/347 and also # https://github.com/ITensor/ITensors.jl/issues/1336. function blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}}) - bs1 = blockrange(inds[1], I[1]) + bs1 = to_blockindices(inds[1], I[1]) I1 = BlockSlice(bs1, blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end +# Special case when there is no blocking. +function blocksparse_to_indices( + a, + inds::Tuple{Base.OneTo{<:Integer},Vararg{Any}}, + I::Tuple{UnitRange{<:Integer},Vararg{Any}}, +) + return (inds[1][I[1]], to_indices(a, Base.tail(inds), Base.tail(I))...) +end + # a[[Block(2), Block(1)], [Block(2), Block(1)]] function blocksparse_to_indices(a, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}}) I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end -# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])] -# Permute and merge blocks. -# TODO: This isn't merging blocks yet, that needs to be implemented that. -function blocksparse_to_indices(a, inds, I::Tuple{BlockVector{<:Block{1}},Vararg{Any}}) +# a[mortar([Block(1)[1:2], Block(2)[1:3]]), mortar([Block(1)[1:2], Block(2)[1:3]])] +# a[[Block(1)[1:2], Block(2)[1:3]], [Block(1)[1:2], Block(2)[1:3]]] +function blocksparse_to_indices( + a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}} +) I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end -# TODO: Should this be combined with the version above? -function blocksparse_to_indices(a, inds, I::Tuple{BlockedVector{<:Block{1}},Vararg{Any}}) - I1 = blockedunitrange_getindices(inds[1], I[1]) +# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])] +# Permute and merge blocks. +# TODO: This isn't merging blocks yet, that needs to be implemented that. +function blocksparse_to_indices( + a, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}} +) + I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -231,13 +246,13 @@ function Base.size(a::SparseSubArrayBlocks) end function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} # TODO: Should this be defined as `@view a.array[Block(I)]` instead? - ## return @view a.array[Block(I)] + return @view a.array[Block(I)] - parent_blocks = @view blocks(parent(a.array))[blockrange(a)...] - parent_block = parent_blocks[I...] - # TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`. - block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a))) - return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...] + ## parent_blocks = @view blocks(parent(a.array))[blockrange(a)...] + ## parent_block = parent_blocks[I...] + ## # TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`. + ## block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a))) + ## return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...] end # TODO: This should be handled by generic `AbstractSparseArray` code. function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} @@ -280,6 +295,15 @@ function blocksparse_blocks(a::SubArray) return SparseSubArrayBlocks(a) end +to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block) +to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks) + +function blocksparse_blocks( + a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}} +) + return @view blocks(parent(a))[map(to_blocks_indices, parentindices(a))...] +end + using BlockArrays: BlocksView # TODO: Is this correct in general? SparseArrayInterface.nstored(a::BlocksView) = 1 diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/views.jl new file mode 100644 index 0000000000..8e43f2625b --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/views.jl @@ -0,0 +1,3 @@ +function blocksparse_view(a, I...) + return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 05799afebc..10f8d6e35d 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -25,22 +25,25 @@ include("TestBlockSparseArraysUtils.jl") @testset "BlockSparseArrays (eltype=$elt)" for elt in (Float32, Float64, ComplexF32, ComplexF64) @testset "Broken" begin - a = BlockSparseArray{elt}([2, 2, 2, 2], [2, 2, 2, 2]) - @views for I in [Block(1, 1), Block(2, 2), Block(3, 3), Block(4, 4)] - a[I] = randn(elt, size(a[I])) - end - - I = blockedrange([4, 4]) - b = @view a[I, I] - @test_broken copy(b) + # TODO: Fix this and turn it into a proper test. + a = BlockSparseArray{elt}([2, 3], [2, 3]) + a[Block(1, 1)] = randn(elt, 2, 2) + a[Block(2, 2)] = randn(elt, 3, 3) + @test_broken a[:, 4] - I = BlockedVector(Block.(1:4), [2, 2]) - b = @view a[I, I] - @test_broken copy(b) + # TODO: Fix this and turn it into a proper test. + a = BlockSparseArray{elt}([2, 3], [2, 3]) + a[Block(1, 1)] = randn(elt, 2, 2) + a[Block(2, 2)] = randn(elt, 3, 3) + @test_broken a[:, [2, 4]] + @test_broken a[[3, 5], [2, 4]] - I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) - b = @view a[I, I] - @test_broken copy(b) + # TODO: Fix this and turn it into a proper test. + a = BlockSparseArray{elt}([2, 3], [2, 3]) + a[Block(1, 1)] = randn(elt, 2, 2) + a[Block(2, 2)] = randn(elt, 3, 3) + @test a[2:4, 4] == Array(a)[2:4, 4] + @test_broken a[4, 2:4] end @testset "Basics" begin a = BlockSparseArray{elt}([2, 3], [2, 3]) @@ -371,10 +374,10 @@ include("TestBlockSparseArraysUtils.jl") b = @view a[Block(2, 2)[1:2, 2:2]] @test size(b) == (2, 1) for i in parentindices(b) - @test i isa BlockSlice{<:BlockIndexRange{1}} + @test i isa UnitRange{Int} end - @test parentindices(b)[1] == BlockSlice(Block(2)[1:2], 3:4) - @test parentindices(b)[2] == BlockSlice(Block(2)[2:2], 5:5) + @test parentindices(b)[1] == 1:2 + @test parentindices(b)[2] == 2:2 a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) x = randn(elt, 1, 2) @@ -406,13 +409,7 @@ include("TestBlockSparseArraysUtils.jl") @views for b in [Block(1, 1), Block(2, 2)] a[b] = randn(elt, size(a[b])) end - for I in ( - Block.(1:2), - [Block(1), Block(2)], - BlockVector([Block(1), Block(2)], [1, 1]), - # TODO: This should merge blocks. - BlockVector([Block(1), Block(2)], [2]), - ) + for I in (Block.(1:2), [Block(1), Block(2)]) b = @view a[I, I] for I in CartesianIndices(a) @test b[I] == a[I] @@ -427,12 +424,7 @@ include("TestBlockSparseArraysUtils.jl") # TODO: Use `blocksizes(a)[Int.(Tuple(b))...]` once available. a[b] = randn(elt, size(a[b])) end - for I in ( - [Block(2), Block(1)], - BlockVector([Block(2), Block(1)], [1, 1]), - # TODO: This should merge blocks. - BlockVector([Block(2), Block(1)], [2]), - ) + for I in ([Block(2), Block(1)],) b = @view a[I, I] @test b[Block(1, 1)] == a[Block(2, 2)] @test b[Block(2, 1)] == a[Block(1, 2)] @@ -574,7 +566,7 @@ include("TestBlockSparseArraysUtils.jl") @test b isa SubArray{<:Any,<:Any,<:BlockSparseArray} @test block_nstored(b) == 1 @test b[Block(1, 1)] == x - @test @view(b[Block(1, 1)]) isa SubArray{<:Any,<:Any,<:BlockSparseArray} + @test @view(b[Block(1, 1)]) isa Matrix{elt} for blck in [Block(2, 1), Block(1, 2), Block(2, 2)] @test iszero(b[blck]) end @@ -625,14 +617,75 @@ include("TestBlockSparseArraysUtils.jl") a = BlockSparseArray{elt}([2, 3], [3, 4]) b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]] for index in parentindices(@view(b[Block(1, 1)])) - @test index isa BlockSlice{<:Block{1}} + @test index isa Base.OneTo{Int} end a = BlockSparseArray{elt}([2, 3], [3, 4]) + a[Block(1, 1)] = randn(elt, 2, 3) b = @view a[Block(1, 1)[1:2, 1:1]] + @test b isa SubArray{elt,2,Matrix{elt}} for i in parentindices(b) - @test i isa BlockSlice{<:BlockIndexRange{1}} + @test i isa UnitRange{Int} + end + + a = BlockSparseArray{elt}([2, 2, 2, 2], [2, 2, 2, 2]) + @views for I in [Block(1, 1), Block(2, 2), Block(3, 3), Block(4, 4)] + a[I] = randn(elt, size(a[I])) + end + for I in (blockedrange([4, 4]), BlockedVector(Block.(1:4), [2, 2])) + b = @view a[I, I] + @test copy(b) == a + @test blocksize(b) == (2, 2) + @test blocklengths.(axes(b)) == ([4, 4], [4, 4]) + @test b[Block(1, 1)] == a[Block.(1:2), Block.(1:2)] + @test b[Block(2, 1)] == a[Block.(3:4), Block.(1:2)] + @test b[Block(1, 2)] == a[Block.(1:2), Block.(3:4)] + @test b[Block(2, 2)] == a[Block.(3:4), Block.(3:4)] + c = @view b[Block(2, 2)] + @test blocksize(c) == (1, 1) + @test c == a[Block.(3:4), Block.(3:4)] + end + + a = BlockSparseArray{elt}([2, 3], [2, 3]) + a[Block(1, 1)] = randn(elt, 2, 2) + a[Block(2, 2)] = randn(elt, 3, 3) + for I in (mortar([Block(1)[2:2], Block(2)[2:3]]), [Block(1)[2:2], Block(2)[2:3]]) + b = @view a[:, I] + @test b == Array(a)[:, [2, 4, 5]] + end + + # Merge and permute blocks. + a = BlockSparseArray{elt}([2, 2, 2, 2], [2, 2, 2, 2]) + @views for I in [Block(1, 1), Block(2, 2), Block(3, 3), Block(4, 4)] + a[I] = randn(elt, size(a[I])) end + for I in ( + BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]), + BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]), + ) + b = @view a[I, I] + J = [Block(4), Block(3), Block(2), Block(1)] + @test b == a[J, J] + @test copy(b) == a[J, J] + @test blocksize(b) == (2, 2) + @test blocklengths.(axes(b)) == ([4, 4], [4, 4]) + @test b[Block(1, 1)] == Array(a)[[7, 8, 5, 6], [7, 8, 5, 6]] + c = @views b[Block(1, 1)][2:3, 2:3] + @test c == Array(a)[[8, 5], [8, 5]] + @test copy(c) == Array(a)[[8, 5], [8, 5]] + c = @view b[Block(1, 1)[2:3, 2:3]] + @test c == Array(a)[[8, 5], [8, 5]] + @test copy(c) == Array(a)[[8, 5], [8, 5]] + end + + # TODO: Add more tests of this, it may + # only be working accidentally. + a = BlockSparseArray{elt}([2, 3], [2, 3]) + a[Block(1, 1)] = randn(elt, 2, 2) + a[Block(2, 2)] = randn(elt, 3, 3) + @test a[2:4, 4] == Array(a)[2:4, 4] + # TODO: Fix this. + @test_broken a[4, 2:4] == Array(a)[4, 2:4] end @testset "view!" begin for blk in ((Block(2, 2),), (Block(2), Block(2))) diff --git a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl index 417a15adf9..883025df12 100644 --- a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl @@ -1,10 +1,14 @@ using BlockArrays: BlockArrays, + AbstractBlockVector, AbstractBlockedUnitRange, Block, + BlockIndex, BlockIndexRange, BlockRange, BlockSlice, + BlockVector, + BlockedOneTo, BlockedUnitRange, BlockedVector, block, @@ -72,18 +76,30 @@ function blockedunitrange_getindices( end # TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly. +# TODO: Make a special case for `BlockedVector{<:Block{1},<:BlockRange{1}}`? +# For example: +# ```julia +# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) +# return blockedrange(blocklengths) +# ``` function blockedunitrange_getindices( - a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1},<:BlockRange{1}} + a::AbstractBlockedUnitRange, indices::AbstractBlockVector{<:Block{1}} ) - blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) - return blockedrange(blocklengths) -end - -# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly. -function blockedunitrange_getindices( - a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1}} -) - return mortar(map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))) + blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) + # We pass `length.(blks)` to `mortar` in order + # to pass block labels to the axes of the output, + # if they exist. This makes it so that + # `only(axes(a[indices])) isa `GradedUnitRange` + # if `a isa `GradedUnitRange`, for example. + # Note there is a more specialized definition: + # ```julia + # function blockedunitrange_getindices( + # a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}} + # ) + # ``` + # that does a better job of preserving labels, since `length` + # may drop labels for certain block types. + return mortar(blks, length.(blks)) end # TODO: Move this to a `BlockArraysExtensions` library. @@ -127,6 +143,13 @@ function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::Block return a[indices] end +function blockedunitrange_getindices( + a::AbstractBlockedUnitRange, + indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}, +) + return mortar(map(b -> a[b], blocks(indices))) +end + # TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices) return error("Not implemented.") @@ -140,3 +163,24 @@ end function _blocks(a::AbstractUnitRange, indices::BlockRange) return indices end + +# Slice `a` by `I`, returning a: +# `BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}` +# with the `BlockIndex{1}` corresponding to each value of `I`. +function to_blockindices(a::BlockedOneTo{<:Integer}, I::UnitRange{<:Integer}) + return mortar( + map(blocks(blockedunitrange_getindices(a, I))) do r + bi_first = findblockindex(a, first(r)) + bi_last = findblockindex(a, last(r)) + @assert block(bi_first) == block(bi_last) + return block(bi_first)[blockindex(bi_first):blockindex(bi_last)] + end, + ) +end + +# This handles non-blocked slices. +# For example: +# a = BlockSparseArray{Float64}([2, 2, 2, 2]) +# I = BlockedVector(Block.(1:4), [2, 2]) +# @views a[I][Block(1)] +to_blockindices(a::Base.OneTo{<:Integer}, I::UnitRange{<:Integer}) = I diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index f5c27b1c55..57e9420d88 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -18,7 +18,9 @@ using BlockArrays: findblock, findblockindex, mortar -using ..LabelledNumbers: LabelledNumbers, LabelledInteger, label, labelled, unlabel +using Compat: allequal +using ..LabelledNumbers: + LabelledNumbers, LabelledInteger, LabelledUnitRange, label, labelled, unlabel const AbstractGradedUnitRange{T<:LabelledInteger} = AbstractBlockedUnitRange{T} @@ -35,6 +37,29 @@ function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T} return unlabel_blocks(a) end +# This is only needed in certain Julia versions below 1.10 +# (for example Julia 1.6). +# TODO: Delete this once we drop Julia 1.6 support. +# The type constraint `T<:Integer` is needed to avoid an ambiguity +# error with a conversion method in Base. +function Base.UnitRange{T}( + a::AbstractGradedUnitRange{<:LabelledInteger{T}} +) where {T<:Integer} + return UnitRange(unlabel_blocks(a)) +end + +# This is only needed in certain Julia versions below 1.10 +# (for example Julia 1.6). +# TODO: Delete this once we drop Julia 1.6 support. +# The type constraint `T<:Integer` is needed to avoid an ambiguity +# error with a conversion method in Base. +using BlockArrays: BlockSlice +function Base.UnitRange{T}( + a::BlockSlice{<:Any,<:LabelledInteger{T},<:AbstractUnitRange{<:LabelledInteger{T}}} +) where {T<:Integer} + return UnitRange{T}(a.indices) +end + # TODO: See if this is needed. function Base.AbstractUnitRange{T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T} return unlabel_blocks(a) @@ -292,3 +317,30 @@ function BlockArrays.combine_blockaxes( ) where {T<:Integer} return BlockArrays.combine_blockaxes(a2, a1) end + +# Version of length that checks that all blocks have the same label +# and returns a labelled length with that label. +function labelled_length(a::AbstractBlockVector{<:Integer}) + blocklabels = label.(blocks(a)) + @assert allequal(blocklabels) + return labelled(unlabel(length(a)), first(blocklabels)) +end + +# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly. +# TODO: Make a special case for `BlockedVector{<:Block{1},<:BlockRange{1}}`? +# For example: +# ```julia +# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) +# return blockedrange(blocklengths) +# ``` +function blockedunitrange_getindices( + a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}} +) + blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) + # We pass `length.(blks)` to `mortar` in order + # to pass block labels to the axes of the output, + # if they exist. This makes it so that + # `only(axes(a[indices])) isa `GradedUnitRange` + # if `a isa `GradedUnitRange`, for example. + return mortar(blks, labelled_length.(blks)) +end diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 495c90a239..aa04cc1600 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -51,6 +51,10 @@ function Base.getindex(a::UnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) return unitrangedual_getindices_blocks(a, indices) end +function to_blockindices(a::UnitRangeDual, indices::UnitRange{<:Integer}) + return to_blockindices(nondual(a), indices) +end + Base.axes(a::UnitRangeDual) = axes(nondual(a)) using BlockArrays: BlockArrays, Block, BlockSlice @@ -104,3 +108,12 @@ function Base.OrdinalRange{Int,Int}( # return Int.(r) return unlabel(nondual(r)) end + +# This is only needed in certain Julia versions below 1.10 +# (for example Julia 1.6). +# TODO: Delete this once we drop Julia 1.6 support. +# The type constraint `T<:Integer` is needed to avoid an ambiguity +# error with a conversion method in Base. +function Base.UnitRange{T}(a::UnitRangeDual{<:LabelledInteger{T}}) where {T<:Integer} + return UnitRange{T}(nondual(a)) +end