From b2fd32b5be1857ebe59dfa40d74d986930c30b97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 19:06:30 -0400 Subject: [PATCH] fix slicing BlockSparseArrays --- .../src/BlockArraysExtensions/BlockArraysExtensions.jl | 2 +- .../blocksparsearrayinterface.jl | 8 ++++---- NDTensors/src/lib/GradedAxes/test/test_basics.jl | 5 ++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 7e7b503475..3214a8a230 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -20,7 +20,7 @@ using BlockArrays: findblockindex using Compat: allequal using Dictionaries: Dictionary, Indices -using ..GradedAxes: blockedunitrange_getindices, to_blockindices +using ..GradedAxes: blockedunitrange_getindices, gradedunitrange_getindices, to_blockindices using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices # A return type for `blocks(array)` when `array` isn't blocked. diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 182504e038..55f202a6a7 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -30,7 +30,7 @@ end # https://github.com/ITensor/ITensors.jl/issues/1336. function blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}}) bs1 = to_blockindices(inds[1], I[1]) - I1 = BlockSlice(bs1, blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockSlice(bs1, gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -45,7 +45,7 @@ end # a[[Block(2), Block(1)], [Block(2), Block(1)]] function blocksparse_to_indices(a, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}}) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -54,7 +54,7 @@ end function blocksparse_to_indices( a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}} ) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -64,7 +64,7 @@ end function blocksparse_to_indices( a, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}} ) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 02d37f718f..15e04fec87 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -9,7 +9,8 @@ using BlockArrays: blocklasts, blocklength, blocklengths, - blocks + blocks, + combine_blockaxes using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel @@ -94,6 +95,8 @@ end @test length(a[Block(2)]) == 3 @test blocklengths(only(axes(a))) == blocklengths(a) @test blocklabels(only(axes(a))) == blocklabels(a) + + @test combine_blockaxes(a, a) isa GradedOneTo end # Slicing operations