Skip to content

Commit

Permalink
fix slicing BlockSparseArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Nov 1, 2024
1 parent 24fb3c8 commit b2fd32b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using BlockArrays:
findblockindex
using Compat: allequal
using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices, to_blockindices
using ..GradedAxes: blockedunitrange_getindices, gradedunitrange_getindices, to_blockindices
using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices

# A return type for `blocks(array)` when `array` isn't blocked.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end
# https://github.com/ITensor/ITensors.jl/issues/1336.
function blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})
bs1 = to_blockindices(inds[1], I[1])
I1 = BlockSlice(bs1, blockedunitrange_getindices(inds[1], I[1]))
I1 = BlockSlice(bs1, gradedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

Expand All @@ -45,7 +45,7 @@ end

# a[[Block(2), Block(1)], [Block(2), Block(1)]]
function blocksparse_to_indices(a, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}})
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

Expand All @@ -54,7 +54,7 @@ end
function blocksparse_to_indices(
a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}}
)
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

Expand All @@ -64,7 +64,7 @@ end
function blocksparse_to_indices(
a, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}}
)
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

Expand Down
5 changes: 4 additions & 1 deletion NDTensors/src/lib/GradedAxes/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using BlockArrays:
blocklasts,
blocklength,
blocklengths,
blocks
blocks,
combine_blockaxes
using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange
using NDTensors.LabelledNumbers:
LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel
Expand Down Expand Up @@ -94,6 +95,8 @@ end
@test length(a[Block(2)]) == 3
@test blocklengths(only(axes(a))) == blocklengths(a)
@test blocklabels(only(axes(a))) == blocklabels(a)

@test combine_blockaxes(a, a) isa GradedOneTo
end

# Slicing operations
Expand Down

0 comments on commit b2fd32b

Please sign in to comment.