From cce97430a03bbd568bd960f3db4b9d961c9d93de Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 9 Nov 2024 10:31:05 -0500 Subject: [PATCH] Add test for permutedims bug --- .../wrappedabstractblocksparsearray.jl | 19 +++++++++++++++++++ .../lib/BlockSparseArrays/test/test_basics.jl | 9 +++++++++ 2 files changed, 28 insertions(+) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 1a684ef08d..dfb797caaa 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -273,3 +273,22 @@ function Base.similar( ) return blocksparse_similar(a, elt, axes) end + +# TODO: Implement this in a more generic way using a smarter `copyto!`, +# which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls. +# These are defined for now to avoid scalar indexing issues when there +# are blocks on GPU. +function Base.Array{T,N}(a::BlockSparseArrayLike{<:Any,N}) where {T,N} + # First make it dense, then move to CPU. + # Directly copying to CPU causes some issues with + # scalar indexing on GPU which we have to investigate. + a_dest = similartype(blocktype(a), T)(undef, size(a)) + a_dest .= a + return Array{T,N}(a_dest) +end +function Base.Array{T}(a::BlockSparseArrayLike) where {T} + return Array{T,ndims(a)}(a) +end +function Base.Array(a::BlockSparseArrayLike) + return Array{eltype(a)}(a) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 7172917caa..3cfaeded6d 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -297,6 +297,15 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype @test block_nstored(b) == 2 @test nstored(b) == 2 * 4 + 3 * 3 + a = dev(BlockSparseArray{elt}([1, 1, 1], [1, 2, 3], [2, 2, 1], [1, 2, 1])) + a[Block(3, 2, 2, 3)] = dev(randn(1, 2, 2, 1)) + perm = (2, 3, 4, 1) + for b in (PermutedDimsArray(a, perm), permutedims(a, perm)) + @test Array(b) == permutedims(Array(a), perm) + @test issetequal(block_stored_indices(b), [Block(2, 2, 3, 3)]) + @test @allowscalar b[Block(2, 2, 3, 3)] == permutedims(a[Block(3, 2, 2, 3)], perm) + end + a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) @views for b in [Block(1, 2), Block(2, 1)] a[b] = randn(elt, size(a[b]))