diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 474e354820..1970127923 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -189,7 +189,9 @@ function Base.similar( end function blocksparse_similar(a, elt::Type, axes::Tuple) - return BlockSparseArray{elt,length(axes),similartype(blocktype(a), axes)}(undef, axes) + return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}( + undef, axes + ) end # Needed by `BlockArrays` matrix multiplication interface diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 732a895286..6b1e760d43 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -125,12 +125,13 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i # Represents the array of arrays of a `PermutedDimsArray` # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`. -struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <: - AbstractSparseArray{T,N} +struct SparsePermutedDimsArrayBlocks{ + T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N} +} <: AbstractSparseArray{BlockType,N} array::Array end function blocksparse_blocks(a::PermutedDimsArray) - return SparsePermutedDimsArrayBlocks(a) + return SparsePermutedDimsArrayBlocks{eltype(a),ndims(a),blocktype(parent(a)),typeof(a)}(a) end function Base.size(a::SparsePermutedDimsArrayBlocks) return _getindices(size(blocks(parent(a.array))), _perm(a.array)) @@ -158,11 +159,12 @@ reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index))) # Represents the array of arrays of a `Transpose` # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`. -struct SparseTransposeBlocks{T,Array<:Transpose{T}} <: AbstractSparseMatrix{T} +struct SparseTransposeBlocks{T,BlockType<:AbstractMatrix{T},Array<:Transpose{T}} <: + AbstractSparseMatrix{BlockType} array::Array end function blocksparse_blocks(a::Transpose) - return SparseTransposeBlocks(a) + return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a) end function Base.size(a::SparseTransposeBlocks) return reverse(size(blocks(parent(a.array)))) @@ -192,11 +194,12 @@ end # Represents the array of arrays of a `Adjoint` # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`. -struct SparseAdjointBlocks{T,Array<:Adjoint{T}} <: AbstractSparseMatrix{T} +struct SparseAdjointBlocks{T,BlockType<:AbstractMatrix{T},Array<:Adjoint{T}} <: + AbstractSparseMatrix{BlockType} array::Array end function blocksparse_blocks(a::Adjoint) - return SparseAdjointBlocks(a) + return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a) end function Base.size(a::SparseAdjointBlocks) return reverse(size(blocks(parent(a.array)))) diff --git a/NDTensors/src/lib/BlockSparseArrays/test/Project.toml b/NDTensors/src/lib/BlockSparseArrays/test/Project.toml index 0946bbaa98..b0460803d3 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/Project.toml +++ b/NDTensors/src/lib/BlockSparseArrays/test/Project.toml @@ -1,4 +1,6 @@ [deps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index ad4b5449bc..6604b21b2b 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -15,6 +15,7 @@ using BlockArrays: blocksizes, mortar using Compat: @compat +using GPUArraysCore: @allowscalar using LinearAlgebra: Adjoint, mul!, norm using NDTensors.BlockSparseArrays: @view!, @@ -28,27 +29,37 @@ using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: contract using Test: @test, @test_broken, @test_throws, @testset include("TestBlockSparseArraysUtils.jl") -@testset "BlockSparseArrays (eltype=$elt)" for elt in - (Float32, Float64, ComplexF32, ComplexF64) + +using NDTensors: NDTensors +include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl")) +using .NDTensorsTestUtils: devices_list, is_supported_eltype +@testset "BlockSparseArrays (dev=$dev, eltype=$elt)" for dev in devices_list(copy(ARGS)), + elt in (Float32, Float64, Complex{Float32}, Complex{Float64}) + + @show dev, elt + + if !is_supported_eltype(dev, elt) + continue + end @testset "Broken" begin # TODO: Fix this and turn it into a proper test. - a = BlockSparseArray{elt}([2, 3], [2, 3]) - a[Block(1, 1)] = randn(elt, 2, 2) - a[Block(2, 2)] = randn(elt, 3, 3) + a = dev(BlockSparseArray{elt}([2, 3], [2, 3])) + a[Block(1, 1)] = dev(randn(elt, 2, 2)) + a[Block(2, 2)] = dev(randn(elt, 3, 3)) @test_broken a[:, 4] # TODO: Fix this and turn it into a proper test. - a = BlockSparseArray{elt}([2, 3], [2, 3]) - a[Block(1, 1)] = randn(elt, 2, 2) - a[Block(2, 2)] = randn(elt, 3, 3) + a = dev(BlockSparseArray{elt}([2, 3], [2, 3])) + a[Block(1, 1)] = dev(randn(elt, 2, 2)) + a[Block(2, 2)] = dev(randn(elt, 3, 3)) @test_broken a[:, [2, 4]] @test_broken a[[3, 5], [2, 4]] # TODO: Fix this and turn it into a proper test. - a = BlockSparseArray{elt}([2, 3], [2, 3]) - a[Block(1, 1)] = randn(elt, 2, 2) - a[Block(2, 2)] = randn(elt, 3, 3) - @test a[2:4, 4] == Array(a)[2:4, 4] + a = dev(BlockSparseArray{elt}([2, 3], [2, 3])) + a[Block(1, 1)] = dev(randn(elt, 2, 2)) + a[Block(2, 2)] = dev(randn(elt, 3, 3)) + @allowscalar @test a[2:4, 4] == Array(a)[2:4, 4] @test_broken a[4, 2:4] @test a[Block(1), :] isa BlockSparseArray{elt}