Skip to content

Commit

Permalink
Start adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 7, 2024
1 parent d1fadaf commit 49cc24c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ function Base.similar(
end

function blocksparse_similar(a, elt::Type, axes::Tuple)
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), axes)}(undef, axes)
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}(
undef, axes
)
end

# Needed by `BlockArrays` matrix multiplication interface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i

# Represents the array of arrays of a `PermutedDimsArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <:
AbstractSparseArray{T,N}
struct SparsePermutedDimsArrayBlocks{
T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N}
} <: AbstractSparseArray{BlockType,N}
array::Array
end
function blocksparse_blocks(a::PermutedDimsArray)
return SparsePermutedDimsArrayBlocks(a)
return SparsePermutedDimsArrayBlocks{eltype(a),ndims(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparsePermutedDimsArrayBlocks)
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
Expand Down Expand Up @@ -158,11 +159,12 @@ reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))

# Represents the array of arrays of a `Transpose`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
struct SparseTransposeBlocks{T,Array<:Transpose{T}} <: AbstractSparseMatrix{T}
struct SparseTransposeBlocks{T,BlockType<:AbstractMatrix{T},Array<:Transpose{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Transpose)
return SparseTransposeBlocks(a)
return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseTransposeBlocks)
return reverse(size(blocks(parent(a.array))))
Expand Down Expand Up @@ -192,11 +194,12 @@ end

# Represents the array of arrays of a `Adjoint`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
struct SparseAdjointBlocks{T,Array<:Adjoint{T}} <: AbstractSparseMatrix{T}
struct SparseAdjointBlocks{T,BlockType<:AbstractMatrix{T},Array<:Adjoint{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Adjoint)
return SparseAdjointBlocks(a)
return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseAdjointBlocks)
return reverse(size(blocks(parent(a.array))))
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
35 changes: 23 additions & 12 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using BlockArrays:
blocksizes,
mortar
using Compat: @compat
using GPUArraysCore: @allowscalar
using LinearAlgebra: Adjoint, mul!, norm
using NDTensors.BlockSparseArrays:
@view!,
Expand All @@ -28,27 +29,37 @@ using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset
include("TestBlockSparseArraysUtils.jl")
@testset "BlockSparseArrays (eltype=$elt)" for elt in
(Float32, Float64, ComplexF32, ComplexF64)

using NDTensors: NDTensors
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))
using .NDTensorsTestUtils: devices_list, is_supported_eltype
@testset "BlockSparseArrays (dev=$dev, eltype=$elt)" for dev in devices_list(copy(ARGS)),
elt in (Float32, Float64, Complex{Float32}, Complex{Float64})

@show dev, elt

if !is_supported_eltype(dev, elt)
continue
end
@testset "Broken" begin
# TODO: Fix this and turn it into a proper test.
a = BlockSparseArray{elt}([2, 3], [2, 3])
a[Block(1, 1)] = randn(elt, 2, 2)
a[Block(2, 2)] = randn(elt, 3, 3)
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
@test_broken a[:, 4]

# TODO: Fix this and turn it into a proper test.
a = BlockSparseArray{elt}([2, 3], [2, 3])
a[Block(1, 1)] = randn(elt, 2, 2)
a[Block(2, 2)] = randn(elt, 3, 3)
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
@test_broken a[:, [2, 4]]
@test_broken a[[3, 5], [2, 4]]

# TODO: Fix this and turn it into a proper test.
a = BlockSparseArray{elt}([2, 3], [2, 3])
a[Block(1, 1)] = randn(elt, 2, 2)
a[Block(2, 2)] = randn(elt, 3, 3)
@test a[2:4, 4] == Array(a)[2:4, 4]
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
@allowscalar @test a[2:4, 4] == Array(a)[2:4, 4]
@test_broken a[4, 2:4]

@test a[Block(1), :] isa BlockSparseArray{elt}
Expand Down

0 comments on commit 49cc24c

Please sign in to comment.