From b6ebf3840068f2b2401732bef0ec90fa69b23be4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 1 Jul 2024 13:04:22 -0400 Subject: [PATCH] Preserve block labels during fusion --- .../test/runtests.jl | 16 +++------- .../lib/GradedAxes/src/blockedunitrange.jl | 16 +++++++++- .../src/lib/GradedAxes/src/gradedunitrange.jl | 31 ++++++++++++++++++- 3 files changed, 49 insertions(+), 14 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index a7a54fb9cb..95e8a71fb9 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -87,14 +87,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) m = fusedims(a, (1, 2), (3, 4)) - # TODO: Once block merging is implemented, this should - # be the real test. for ax in axes(m) @test ax isa GradedOneTo - # TODO: Current `fusedims` doesn't merge - # common sectors, need to fix. - @test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)] - @test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)] + @test blocklabels(ax) == [U1(0), U1(1), U1(2)] end for I in CartesianIndices(m) if I ∈ CartesianIndex.([(1, 1), (4, 4)]) @@ -104,12 +99,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end end @test a[1, 1, 1, 1] == m[1, 1] - @test a[2, 2, 2, 2] == m[4, 4] - # TODO: Current `fusedims` doesn't merge - # common sectors, need to fix. - @test_broken blocksize(m) == (3, 3) - @test blocksize(m) == (4, 4) - @test a == splitdims(m, (d1, d2), (d1, d2)) + @test_broken a[2, 2, 2, 2] == m[4, 4] + @test blocksize(m) == (3, 3) + @test_broken a == splitdims(m, (d1, d2), (d1, d2)) end @testset "dual axes" begin r = gradedrange([U1(0) => 2, U1(1) => 2]) diff --git a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl index 5512126ad8..883025df12 100644 --- a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl @@ -85,7 +85,21 @@ end function blockedunitrange_getindices( a::AbstractBlockedUnitRange, indices::AbstractBlockVector{<:Block{1}} ) - return mortar(map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))) + blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) + # We pass `length.(blks)` to `mortar` in order + # to pass block labels to the axes of the output, + # if they exist. This makes it so that + # `only(axes(a[indices])) isa `GradedUnitRange` + # if `a isa `GradedUnitRange`, for example. + # Note there is a more specialized definition: + # ```julia + # function blockedunitrange_getindices( + # a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}} + # ) + # ``` + # that does a better job of preserving labels, since `length` + # may drop labels for certain block types. + return mortar(blks, length.(blks)) end # TODO: Move this to a `BlockArraysExtensions` library. diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index f5c27b1c55..84f5ccb10a 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -18,7 +18,9 @@ using BlockArrays: findblock, findblockindex, mortar -using ..LabelledNumbers: LabelledNumbers, LabelledInteger, label, labelled, unlabel +using Compat: allequal +using ..LabelledNumbers: + LabelledNumbers, LabelledInteger, LabelledUnitRange, label, labelled, unlabel const AbstractGradedUnitRange{T<:LabelledInteger} = AbstractBlockedUnitRange{T} @@ -292,3 +294,30 @@ function BlockArrays.combine_blockaxes( ) where {T<:Integer} return BlockArrays.combine_blockaxes(a2, a1) end + +# Version of length that checks that all blocks have the same label +# and returns a labelled length with that label. +function labelled_length(a::AbstractBlockVector{<:Integer}) + blocklabels = label.(blocks(a)) + @assert allequal(blocklabels) + return labelled(unlabel(length(a)), first(blocklabels)) +end + +# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly. +# TODO: Make a special case for `BlockedVector{<:Block{1},<:BlockRange{1}}`? +# For example: +# ```julia +# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) +# return blockedrange(blocklengths) +# ``` +function blockedunitrange_getindices( + a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}} +) + blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) + # We pass `length.(blks)` to `mortar` in order + # to pass block labels to the axes of the output, + # if they exist. This makes it so that + # `only(axes(a[indices])) isa `GradedUnitRange` + # if `a isa `GradedUnitRange`, for example. + return mortar(blks, labelled_length.(blks)) +end