From 6d6b47f225198cafdf7879b6098a399fe24f7199 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 1 Jul 2024 14:56:45 -0400 Subject: [PATCH] Fix splitdims --- .../src/BlockSparseArraysGradedAxesExt.jl | 7 ++++++- .../ext/BlockSparseArraysGradedAxesExt/test/runtests.jl | 2 +- .../BlockSparseArrays/src/abstractblocksparsearray/map.jl | 4 ++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl index adf7b35de0..3a815d4ea2 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl @@ -66,7 +66,12 @@ function TensorAlgebra.splitdims( return length(axis) ≤ length(axes(a, i)) end blockperms = invblockperm.(blocksortperm.(axes_prod)) - a_blockpermed = a[blockperms...] + # TODO: This is doing extra copies of the blocks, + # use `@view a[axes_prod...]` instead. + # That will require implementing some reindexing logic + # for this combination of slicing. + a_unblocked = a[axes_prod...] + a_blockpermed = a_unblocked[blockperms...] return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...) end diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 3b12cdca47..38142b65f5 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -101,7 +101,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[1, 1, 1, 1] == m[1, 1] @test a[2, 2, 2, 2] == m[4, 4] @test blocksize(m) == (3, 3) - @test_broken a == splitdims(m, (d1, d2), (d1, d2)) + @test a == splitdims(m, (d1, d2), (d1, d2)) end @testset "dual axes" begin r = gradedrange([U1(0) => 2, U1(1) => 2]) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 14e431efe7..b9ab510566 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -57,6 +57,10 @@ function reblock( return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] end +# TODO: Rewrite this so that it takes the blocking structure +# made by combining the blocking of the axes (i.e. the blocking that +# is used to determine `union_stored_blocked_cartesianindices(...)`). +# `reblock` is a partial solution to that, but a bit ad-hoc. # TODO: Move to `blocksparsearrayinterface/map.jl`. function SparseArrayInterface.sparse_map!( ::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}