Skip to content

Commit

Permalink
Add test for permutedims bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 9, 2024
1 parent a2cf2f5 commit cce9743
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,22 @@ function Base.similar(
)
return blocksparse_similar(a, elt, axes)
end

# TODO: Implement this in a more generic way using a smarter `copyto!`,
# which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls.
# These are defined for now to avoid scalar indexing issues when there
# are blocks on GPU.
function Base.Array{T,N}(a::BlockSparseArrayLike{<:Any,N}) where {T,N}
# First make it dense, then move to CPU.
# Directly copying to CPU causes some issues with
# scalar indexing on GPU which we have to investigate.
a_dest = similartype(blocktype(a), T)(undef, size(a))
a_dest .= a
return Array{T,N}(a_dest)
end
function Base.Array{T}(a::BlockSparseArrayLike) where {T}
return Array{T,ndims(a)}(a)
end
function Base.Array(a::BlockSparseArrayLike)
return Array{eltype(a)}(a)
end
9 changes: 9 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,15 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test block_nstored(b) == 2
@test nstored(b) == 2 * 4 + 3 * 3

a = dev(BlockSparseArray{elt}([1, 1, 1], [1, 2, 3], [2, 2, 1], [1, 2, 1]))
a[Block(3, 2, 2, 3)] = dev(randn(1, 2, 2, 1))
perm = (2, 3, 4, 1)
for b in (PermutedDimsArray(a, perm), permutedims(a, perm))
@test Array(b) == permutedims(Array(a), perm)
@test issetequal(block_stored_indices(b), [Block(2, 2, 3, 3)])
@test @allowscalar b[Block(2, 2, 3, 3)] == permutedims(a[Block(3, 2, 2, 3)], perm)
end

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
@views for b in [Block(1, 2), Block(2, 1)]
a[b] = randn(elt, size(a[b]))
Expand Down

0 comments on commit cce9743

Please sign in to comment.