Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GradedAxes] [BlockSparseArrays] Fix ambiguity error when slicing GradedUnitRange with BlockSlice #1491

Merged
merged 6 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.24"
version = "0.3.25"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,37 +95,48 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
end
@testset "dual axes" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(dual(r), r)
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
# TODO: Define and use `isdual` here.
@test axes(a, 1) isa UnitRangeDual
@test axes(a, 2) isa GradedUnitRange
@test !(axes(a, 2) isa UnitRangeDual)
a_dense = Array(a)
@test eachindex(a) == CartesianIndices(size(a))
for I in eachindex(a)
@test a[I] == a_dense[I]
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test axes(a', 1) isa UnitRangeDual
@test axes(a', 2) isa GradedUnitRange
@test !(axes(a', 2) isa UnitRangeDual)
@test isnothing(show(devnull, MIME("text/plain"), a))

# Check preserving dual in tensor algebra.
for b in (a + a, 2 * a, 3 * a - a)
@test Array(b) ≈ 2 * Array(a)
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
a = BlockSparseArray{elt}(ax...)
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
# TODO: Define and use `isdual` here.
@test axes(b, 1) isa UnitRangeDual
@test axes(b, 2) isa GradedUnitRange
@test !(axes(b, 2) isa UnitRangeDual)
end
for dim in 1:ndims(a)
@test typeof(ax[dim]) === typeof(axes(a, dim))
end
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
@test @view(a[Block(1, 1)])[1, 2] == a[1, 2]
@test @view(a[Block(1, 1)])[2, 2] == a[2, 2]
@test @view(a[Block(2, 2)])[1, 1] == a[3, 3]
@test @view(a[Block(2, 2)])[2, 1] == a[4, 3]
@test @view(a[Block(2, 2)])[1, 2] == a[3, 4]
@test @view(a[Block(2, 2)])[2, 2] == a[4, 4]
@test @view(a[Block(1, 1)])[1:2, 1:2] == a[1:2, 1:2]
@test @view(a[Block(2, 2)])[1:2, 1:2] == a[3:4, 3:4]
a_dense = Array(a)
@test eachindex(a) == CartesianIndices(size(a))
for I in eachindex(a)
@test a[I] == a_dense[I]
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))
@test isnothing(show(devnull, MIME("text/plain"), a))

@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
# Check preserving dual in tensor algebra.
for b in (a + a, 2 * a, 3 * a - a)
@test Array(b) ≈ 2 * Array(a)
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(axes(b, dim)) === typeof(axes(b, dim))
end
end

@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
end

# Test case when all axes are dual.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
struct GenericBlockSlice{B,T<:Integer,I<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
block::B
indices::I
end
BlockArrays.Block(bs::GenericBlockSlice{<:Block}) = bs.block
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
@eval Base.$f(S::GenericBlockSlice) = Base.$f(S.indices)
end
Base.getindex(S::GenericBlockSlice, i::Integer) = getindex(S.indices, i)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
blocks::B
indices::I
Expand Down Expand Up @@ -175,6 +190,13 @@ function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
function blockrange(axis::AbstractUnitRange, r::GenericBlockSlice)
return blockrange(axis, r.block)
end

function blockrange(a::AbstractUnitRange, r::BlockIndices)
return blockrange(a, r.blocks)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ end
# TODO: Move to blocksparsearrayinterface.
function blocksparse_unblock(a, inds, I::Tuple{AbstractUnitRange{<:Integer},Vararg{Any}})
bs = blockrange(inds[1], I[1])
return BlockSlice(bs, blockedunitrange_getindices(inds[1], I[1]))
# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
return GenericBlockSlice(bs, blockedunitrange_getindices(inds[1], I[1]))
end

function BlockArrays.unblock(a, inds, I::Tuple{AbstractVector{<:Block{1}},Vararg{Any}})
Expand Down
7 changes: 7 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using BlockArrays:
Block,
BlockIndexRange,
BlockRange,
BlockSlice,
BlockedUnitRange,
block,
blockindex,
Expand Down Expand Up @@ -73,6 +74,12 @@ function blockedunitrange_getindices(a::BlockedUnitRange, indices::BlockIndexRan
return a[block(indices)][only(indices.indices)]
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::BlockedUnitRange, indices::BlockSlice)
# TODO: Is this a good definition? It ignores `indices.indices`.
return a[indices.block]
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::BlockedUnitRange, indices::Vector{<:Integer})
return map(index -> a[index], indices)
Expand Down
23 changes: 23 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using BlockArrays:
BlockedUnitRange,
BlockIndex,
BlockRange,
BlockSlice,
BlockVector,
blockedrange,
BlockIndexRange,
Expand Down Expand Up @@ -165,6 +166,15 @@ function blockedunitrange_getindices(
return labelled_blocks(a_indices, blocklabels(ga, indices))
end

# Fixes ambiguity error with:
# ```julia
# blockedunitrange_getindices(::GradedUnitRange, ::AbstractUnitRange{<:Integer})
# ```
# TODO: Try removing once GradedAxes is rewritten for BlockArrays v1.
function blockedunitrange_getindices(a::GradedUnitRange, indices::BlockSlice)
return a[indices.block]
end

function blockedunitrange_getindices(ga::GradedUnitRange, indices::BlockRange)
return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices))
end
Expand Down Expand Up @@ -200,6 +210,19 @@ function Base.getindex(a::GradedUnitRange, indices::BlockIndex{1})
return blockedunitrange_getindices(a, indices)
end

# Fixes ambiguity issues with:
# ```julia
# getindex(::BlockedUnitRange, ::BlockSlice)
# getindex(::GradedUnitRange, ::AbstractUnitRange{<:Integer})
# getindex(::GradedUnitRange, ::Any)
# getindex(::AbstractUnitRange, ::AbstractUnitRange{<:Integer})
# ```
# TODO: Maybe not needed once GradedAxes is rewritten
# for BlockArrays v1.
function Base.getindex(a::GradedUnitRange, indices::BlockSlice)
return blockedunitrange_getindices(a, indices)
end

function Base.getindex(a::GradedUnitRange, indices)
return blockedunitrange_getindices(a, indices)
end
Expand Down
15 changes: 15 additions & 0 deletions NDTensors/src/lib/GradedAxes/test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@eval module $(gensym())
using BlockArrays:
Block,
BlockSlice,
BlockVector,
blockedrange,
blockfirsts,
Expand Down Expand Up @@ -87,6 +88,20 @@ using Test: @test, @test_broken, @testset
@test blocklengths(ax) == blocklengths(a)
@test blocklabels(ax) == blocklabels(a)

# Regression test for ambiguity error.
x = gradedrange(["x" => 2, "y" => 3])
a = x[BlockSlice(Block(1), Base.OneTo(2))]
@test length(a) == 2
@test a == 1:2
@test blocklength(a) == 1
# TODO: Should this be a `GradedUnitRange`,
# or maybe just a `LabelledUnitRange`?
@test a isa LabelledUnitRange
@test length(a[Block(1)]) == 2
@test label(a) == "x"
@test a[Block(1)] == 1:2
@test label(a[Block(1)]) == "x"

x = gradedrange(["x" => 2, "y" => 3])
a = x[3:4]
@test a isa GradedUnitRange
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/lib/LabelledNumbers/src/LabelledNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ include("labellednumber.jl")
include("labelledinteger.jl")
include("labelledarray.jl")
include("labelledunitrange.jl")
include("LabelledNumbersBlockArraysExt.jl")
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using BlockArrays: BlockArrays, Block, blockaxes, blockfirsts, blocklasts

# Fixes ambiguity error with:
# ```julia
# getindex(::LabelledUnitRange, ::Any...)
# getindex(::AbstractArray{<:Any,N}, ::Block{N}) where {N}
# getindex(::AbstractArray, ::Block{1}, ::Any...)
# ```
function Base.getindex(a::LabelledUnitRange, index::Block{1})
@boundscheck index == Block(1) || throw(BlockBoundsError(a, index))
return a
end

function BlockArrays.blockaxes(a::LabelledUnitRange)
return blockaxes(unlabel(a))
end
function BlockArrays.blockfirsts(a::LabelledUnitRange)
return blockfirsts(unlabel(a))
end
function BlockArrays.blocklasts(a::LabelledUnitRange)
return blocklasts(unlabel(a))
end
1 change: 1 addition & 0 deletions NDTensors/src/lib/LabelledNumbers/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15 changes: 14 additions & 1 deletion NDTensors/src/lib/LabelledNumbers/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@eval module $(gensym())
using LinearAlgebra: norm
using NDTensors.LabelledNumbers: LabelledInteger, islabelled, label, labelled, unlabel
using NDTensors.LabelledNumbers:
LabelledInteger, LabelledUnitRange, islabelled, label, labelled, unlabel
using Test: @test, @testset
@testset "LabelledNumbers" begin
@testset "Labelled number ($n)" for n in (2, 2.0)
Expand Down Expand Up @@ -112,4 +113,16 @@ using Test: @test, @testset
end
end
end

using BlockArrays: Block, blockaxes, blocklength, blocklengths
@testset "LabelledNumbersBlockArraysExt" begin
x = labelled(1:2, "x")
@test blockaxes(x) == (Block.(1:1),)
@test blocklength(x) == 1
@test blocklengths(x) == [2]
a = x[Block(1)]
@test a == 1:2
@test a isa LabelledUnitRange
@test label(a) == "x"
end
end
Loading