From e0758259aeeb84a781814654cb50306379d156ee Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 28 Jun 2024 12:41:59 -0400 Subject: [PATCH] Fix some tests --- .../BlockArraysExtensions.jl | 13 ++++- .../src/BlockSparseArrays.jl | 2 +- .../src/abstractblocksparsearray/views.jl | 51 ++++++++++++------- .../blocksparsearrayinterface.jl | 13 ++--- .../lib/BlockSparseArrays/test/test_basics.jl | 8 ++- 5 files changed, 57 insertions(+), 30 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 499fd42089..c821048f06 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -3,6 +3,8 @@ using BlockArrays: AbstractBlockArray, AbstractBlockVector, Block, + BlockIndex, + BlockIndexRange, BlockRange, BlockedOneTo, BlockedUnitRange, @@ -50,6 +52,15 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}}) return BlockIndices(subblocks, subindices) end +# Similar to the definition of `BlockArrays.BlockSlices`: +# ```julia +# const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}} +# ``` +# but includes `BlockIndices`, where the blocks aren't contiguous. +const BlockSliceCollection = Union{ + BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}} +} + # TODO: This is type piracy. This is used in `reindex` when making # views of blocks of sliced block arrays, for example: # ```julia @@ -423,7 +434,7 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher return a end -function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N} +function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N} return view!(a, Tuple(index)...) end function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N} diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index a31d0269d5..576cfb2d29 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -1,4 +1,5 @@ module BlockSparseArrays +include("BlockArraysExtensions/BlockArraysExtensions.jl") include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") @@ -15,7 +16,6 @@ include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") -include("BlockArraysExtensions/BlockArraysExtensions.jl") include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl") include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl") include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl index 23efb9c283..a321b49f25 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl @@ -1,4 +1,4 @@ -using BlockArrays: BlockArrays, Block, BlockSlices, viewblock +using BlockArrays: BlockArrays, Block, viewblock function blocksparse_view(a, I...) return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...) @@ -8,17 +8,24 @@ end # https://github.com/JuliaArrays/BlockArrays.jl/blob/master/src/views.jl # which don't handle subslices of blocks properly. function Base.view( - a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Block{N} + a::SubArray{ + <:Any,N,<:BlockSparseArrayLike,<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} + }, + I::Block{N}, ) where {N} return blocksparse_view(a, I) end function Base.view( - a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Vararg{Block{1},N} + a::SubArray{ + <:Any,N,<:BlockSparseArrayLike,<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} + }, + I::Vararg{Block{1},N}, ) where {N} return blocksparse_view(a, I...) end function Base.view( - V::SubArray{<:Any,1,<:BlockSparseArrayLike,<:Tuple{BlockSlices}}, I::Block{1} + V::SubArray{<:Any,1,<:BlockSparseArrayLike,<:Tuple{BlockSlice{<:BlockRange{1}}}}, + I::Block{1}, ) return blocksparse_view(a, I) end @@ -42,38 +49,46 @@ function BlockArrays.viewblock( end function Base.view( - a::SubArray{ - T, - N, - <:AbstractBlockSparseArray{T,N}, - <:Tuple{Vararg{BlockSlice{<:BlockRange{1,<:Tuple{<:AbstractUnitRange{<:Integer}}}},N}}, - }, + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}}, block::Block{N}, ) where {T,N} return viewblock(a, block) end function Base.view( - a::SubArray{ - T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} - }, + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}}, block::Vararg{Block{1},N}, ) where {T,N} return viewblock(a, block...) end function BlockArrays.viewblock( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}}, + block::Block{N}, +) where {T,N} + return viewblock(a, Tuple(block)...) +end + +# Fixes ambiguity error with `BlockSparseArrayLike` definition. +function Base.view( a::SubArray{ - T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSlice{<:BlockRange{1}}}} + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} }, block::Block{N}, ) where {T,N} - return viewblock(a, Tuple(block)...) + return viewblock(a, block) +end +# Fixes ambiguity error with `BlockSparseArrayLike` definition. +function Base.view( + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSlice{<:BlockRange{1}},N}} + }, + block::Vararg{Block{1},N}, +) where {T,N} + return viewblock(a, block...) end # TODO: Define `blocksparse_viewblock`. function BlockArrays.viewblock( - a::SubArray{ - T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSlice{<:BlockRange{1}}}} - }, + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}}, block::Vararg{Block{1},N}, ) where {T,N} I = CartesianIndex(Int.(block)) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 35ab110c99..a53b39fc8f 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -14,7 +14,6 @@ using BlockArrays: findblockindex using LinearAlgebra: Adjoint, Transpose using ..SparseArrayInterface: perm, iperm, nstored, sparse_zero! -## using MappedArrays: mappedarray blocksparse_blocks(a::AbstractArray) = error("Not implemented") @@ -282,15 +281,13 @@ function blocksparse_blocks(a::SubArray) return SparseSubArrayBlocks(a) end +_blocks(I::BlockSlice) = I.block +_blocks(I::BlockIndices) = I.blocks + function blocksparse_blocks( - a::SubArray{ - <:Any, - <:Any, - <:Any, - <:Tuple{Vararg{BlockSlice{<:BlockRange{1,<:Tuple{<:AbstractUnitRange{<:Integer}}}}}}, - }, + a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}} ) - return @view blocks(parent(a))[map(i -> Int.(i.block), parentindices(a))...] + return @view blocks(parent(a))[map(I -> Int.(_blocks(I)), parentindices(a))...] end using BlockArrays: BlocksView diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 05799afebc..7c7a5d6b76 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -32,15 +32,19 @@ include("TestBlockSparseArraysUtils.jl") I = blockedrange([4, 4]) b = @view a[I, I] - @test_broken copy(b) + @test copy(b) == a I = BlockedVector(Block.(1:4), [2, 2]) b = @view a[I, I] - @test_broken copy(b) + @test copy(b) == a I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) b = @view a[I, I] @test_broken copy(b) + + a = BlockSparseArray{elt}([2, 3], [2, 3]) + a[Block(1, 1)] = randn(elt, 2, 2) + @test_broken @view(a[Block(1, 1)[1:2, 2:2]]) isa SubArray{elt,2,Matrix{elt}} end @testset "Basics" begin a = BlockSparseArray{elt}([2, 3], [2, 3])