Skip to content

Commit

Permalink
Add and fix some more tests of slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 3, 2024
1 parent 8c677d6 commit 64e5597
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,22 @@ function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:Block{1}})
return sub_axis(a, Block(indices))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockIndexRange{1}})
return sub_axis(a, indices.block)
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::Block)
return sub_axis(a, [indices])
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockIndexRange)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
Expand Down Expand Up @@ -154,6 +166,10 @@ function blockrange(axis::AbstractUnitRange, r::Block{1})
return r:r
end

function blockrange(axis::AbstractUnitRange, r::BlockIndexRange)
return Block(r):Block(r)
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@ end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
# TODO: Make more generic for GPU.
a_dest = BlockSparseArray{eltype(a)}(axes)
a_dest .= a
return a_dest
end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
)
# TODO: Make more generic for GPU.
a_dest = Array{eltype(a)}(undef, length.(axes))
a_dest .= a
return a_dest
end
18 changes: 18 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,24 @@ include("TestBlockSparseArraysUtils.jl")
# BlockArrays.jl v1.
@test b[Block(1, 1)] == trues(size(@view(b[Block(1, 1)])))

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
x = randn(elt, 1, 2)
@view(a[Block(2, 2)])[1:1, 1:2] = x
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
@test a[Block(2, 2)][1:1, 1:2] == x

# TODO: This is broken, fix!
@test_broken a[3:3, 4:5] == x

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
x = randn(elt, 1, 2)
@views a[Block(2, 2)][1:1, 1:2] = x
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
@test a[Block(2, 2)][1:1, 1:2] == x

# TODO: This is broken, fix!
@test_broken a[3:3, 4:5] == x

## Broken, need to fix.

# This is outputting only zero blocks.
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ function blockedunitrange_getindices(
return mortar(map(index -> a[index], indices))
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::BlockedUnitRange, indices::Block{1})
return a[indices]
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::BlockedUnitRange, indices)
return error("Not implemented.")
Expand Down

0 comments on commit 64e5597

Please sign in to comment.