diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 1c058a8efa..bb2634ce14 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -215,7 +215,7 @@ end function blockrange( axis::AbstractUnitRange, - r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}}, + r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}}, ) return map(b -> Block(b), blocks(r)) end @@ -271,7 +271,7 @@ end function blockindices( a::AbstractUnitRange, b::Block, - r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}}, + r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}}, ) # TODO: Change to iterate over `BlockRange(r)` # once https://github.com/JuliaArrays/BlockArrays.jl/issues/404 diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 8b3d37bd36..8b51619f99 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -120,19 +120,18 @@ blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) blocktype(arraytype::Type{<:BlockSparseArrayLike}) = eltype(blockstype(arraytype)) using ArrayLayouts: ArrayLayouts -## function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::Vararg{Int,N}) where {N} -## return ArrayLayouts.layout_getindex(a, I...) -## end function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::CartesianIndices{N}) where {N} return ArrayLayouts.layout_getindex(a, I) end function Base.getindex( - a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange,N} + a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange{<:Integer},N} ) where {N} return ArrayLayouts.layout_getindex(a, I...) end # TODO: Define `AnyBlockSparseMatrix`. -function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange,2}) +function Base.getindex( + a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange{<:Integer},2} +) return ArrayLayouts.layout_getindex(a, I...) end @@ -199,7 +198,7 @@ end # Needed by `BlockArrays` matrix multiplication interface function Base.similar( - arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractUnitRange}} + arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}} ) return similar(arraytype, eltype(arraytype), axes) end @@ -210,37 +209,26 @@ end # Delete once we drop support for older versions of Julia. function Base.similar( arraytype::Type{<:BlockSparseArrayLike}, - axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}}, -) - return similar(arraytype, eltype(arraytype), axes) -end - -# Needed by `BlockArrays` matrix multiplication interface -# Fixes ambiguity error with `BlockArrays.jl`. -function Base.similar( - arraytype::Type{<:BlockSparseArrayLike}, - axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}}, + axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) return similar(arraytype, eltype(arraytype), axes) end -# Needed by `BlockArrays` matrix multiplication interface -# Fixes ambiguity error with `BlockArrays.jl`. +# Fixes ambiguity error with `BlockArrays`. function Base.similar( arraytype::Type{<:BlockSparseArrayLike}, - axes::Tuple{ - AbstractBlockedUnitRange,AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}} - }, + axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) return similar(arraytype, eltype(arraytype), axes) end -# Needed by `BlockArrays` matrix multiplication interface -# Fixes ambiguity error with `BlockArrays.jl`. +# Fixes ambiguity error with `BlockArrays`. function Base.similar( arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{ - AbstractUnitRange{Int},AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}} + AbstractUnitRange{<:Integer}, + AbstractBlockedUnitRange{<:Integer}, + Vararg{AbstractUnitRange{<:Integer}}, }, ) return similar(arraytype, eltype(arraytype), axes) @@ -248,7 +236,8 @@ end # Needed for disambiguation function Base.similar( - arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractBlockedUnitRange}} + arraytype::Type{<:BlockSparseArrayLike}, + axes::Tuple{Vararg{AbstractBlockedUnitRange{<:Integer}}}, ) return similar(arraytype, eltype(arraytype), axes) end @@ -256,7 +245,9 @@ end # Needed by `BlockArrays` matrix multiplication interface # TODO: Define a `blocksparse_similar` function. function Base.similar( - arraytype::Type{<:BlockSparseArrayLike}, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}} + arraytype::Type{<:BlockSparseArrayLike}, + elt::Type, + axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}, ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. @@ -265,7 +256,7 @@ end # TODO: Define a `blocksparse_similar` function. function Base.similar( - a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}} + a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}} ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. @@ -277,7 +268,9 @@ end function Base.similar( a::BlockSparseArrayLike, elt::Type, - axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractBlockedUnitRange}}, + axes::Tuple{ + AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}} + }, ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. @@ -289,13 +282,37 @@ end function Base.similar( a::BlockSparseArrayLike, elt::Type, - axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}}, + axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. return BlockSparseArray{elt}(undef, axes) end +# Fixes ambiguity error with `BlockArrays`. +function Base.similar( + a::BlockSparseArrayLike, + elt::Type, + axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, +) + # TODO: Make generic for GPU, maybe using `blocktype`. + # TODO: For non-block axes this should output `Array`. + return BlockSparseArray{elt}(undef, axes) +end + +# Fixes ambiguity errors with BlockArrays. +function Base.similar( + a::BlockSparseArrayLike, + elt::Type, + axes::Tuple{ + AbstractUnitRange{<:Integer}, + AbstractBlockedUnitRange{<:Integer}, + Vararg{AbstractUnitRange{<:Integer}}, + }, +) + return BlockSparseArray{elt}(undef, axes) +end + # TODO: Define a `blocksparse_similar` function. # Fixes ambiguity error with `StaticArrays`. function Base.similar( diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 1f381fff65..e844bcb548 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -22,7 +22,24 @@ using Test: @test, @test_broken, @test_throws, @testset include("TestBlockSparseArraysUtils.jl") @testset "BlockSparseArrays (eltype=$elt)" for elt in (Float32, Float64, ComplexF32, ComplexF64) - @testset "Broken" begin end + @testset "Broken" begin + a = BlockSparseArray{elt}([2, 3], [3, 4]) + @test_broken a[Block(1, 2)] .= 1 + + a = BlockSparseArray{elt}([2, 3], [3, 4]) + b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]] + @test_broken b[2:4, 2:4] + + a = BlockSparseArray{elt}([2, 3], [3, 4]) + b = @views a[Block(1, 1)][1:2, 1:1] + for i in parentindices(b) + @test_broken i isa BlockSlice{<:BlockIndexRange{1}} + end + + a = BlockSparseArray{elt}([2, 3], [3, 4]) + b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]] + @test_broken b[Block(1, 1)] = randn(3, 3) + end @testset "Basics" begin a = BlockSparseArray{elt}([2, 3], [2, 3]) @test a == BlockSparseArray{elt}(blockedrange([2, 3]), blockedrange([2, 3])) @@ -472,6 +489,26 @@ include("TestBlockSparseArraysUtils.jl") @test a[Block(2, 2)] == x @test b[Block(2, 2)] == x end + + a = BlockSparseArray{elt}([2, 3], [3, 4]) + b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]] + x = randn(elt, 3, 4) + b[Block(1, 1)] .= x + @test b[Block(1, 1)] == x + @test a[Block(2, 2)] == x + @test_throws DimensionMismatch b[Block(1, 1)] .= randn(2, 3) + + a = BlockSparseArray{elt}([2, 3], [3, 4]) + b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]] + for index in parentindices(@view(b[Block(1, 1)])) + @test index isa BlockSlice{<:Block{1}} + end + + a = BlockSparseArray{elt}([2, 3], [3, 4]) + b = @view a[Block(1, 1)[1:2, 1:1]] + for i in parentindices(b) + @test i isa BlockSlice{<:BlockIndexRange{1}} + end end @testset "view!" begin for blk in ((Block(2, 2),), (Block(2), Block(2)))