Skip to content

Commit

Permalink
fix combine_blockaxes
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Nov 1, 2024
1 parent de6080b commit 24fb3c8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
@test axes(a, 1) isa GradedOneTo
@test axes(view(a, 1:4, 1:4), 1) isa GradedOneTo

for b in (a + a, 2 * a)
@test size(b) == (4, 4, 4, 4)
Expand Down
15 changes: 12 additions & 3 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ using BlockArrays:
blocklengths,
findblock,
findblockindex,
mortar
mortar,
sortedunion
using Compat: allequal
using FillArrays: Fill
using ..LabelledNumbers:
Expand Down Expand Up @@ -304,17 +305,25 @@ end
# that mixed dense and graded axes.
# TODO: Maybe come up with a more general solution.
function BlockArrays.combine_blockaxes(
a1::GradedOneTo{<:LabelledInteger{T}}, a2::Base.OneTo{T}
a1::AbstractGradedUnitRange{T}, a2::Base.OneTo{T}
) where {T<:Integer}
combined_blocklasts = sort!(union(unlabel.(blocklasts(a1)), blocklasts(a2)))
return BlockedOneTo(combined_blocklasts)
end
function BlockArrays.combine_blockaxes(
a1::Base.OneTo{T}, a2::GradedOneTo{<:LabelledInteger{T}}
a1::Base.OneTo{T}, a2::AbstractGradedUnitRange{T}
) where {T<:Integer}
return BlockArrays.combine_blockaxes(a2, a1)
end

# preserve labels inside combine_blockaxes
# TODO dual
function BlockArrays.combine_blockaxes(
a::AbstractGradedUnitRange, b::AbstractGradedUnitRange
)
return gradedrange(sortedunion(blocklasts(a), blocklasts(b)))
end

# Version of length that checks that all blocks have the same label
# and returns a labelled length with that label.
function labelled_length(a::AbstractBlockVector{<:Integer})
Expand Down

0 comments on commit 24fb3c8

Please sign in to comment.