Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GradedAxes] Introduce LabelledUnitRangeDual #1571

Merged
merged 17 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using Test: @test, @testset
using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
Expand Down Expand Up @@ -217,10 +217,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
@test isdual(axes(a[I, :], 1))
@test isdual(axes(a[:, I], 2))
@test isdual(axes(a[I, I], 1))
@test isdual(axes(a[I, I], 2))
end

@testset "dual GradedUnitRange" begin
Expand All @@ -243,10 +243,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
@test isdual(axes(a[I, :], 1))
@test isdual(axes(a[:, I], 2))
@test isdual(axes(a[I, I], 1))
@test isdual(axes(a[I, I], 2))
end

@testset "dual BlockedUnitRange" begin # self dual
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/lib/GradedAxes/src/GradedAxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GradedAxes
include("blockedunitrange.jl")
include("gradedunitrange.jl")
include("dual.jl")
include("labelledunitrangedual.jl")
include("gradedunitrangedual.jl")
include("onetoone.jl")
include("fusion.jl")
Expand Down
50 changes: 42 additions & 8 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange)
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockIndexRange)
return dual(nondual(a)[indices])
end

# fix ambiguity
Expand All @@ -49,20 +53,50 @@ function BlockArrays.blocklengths(a::GradedUnitRangeDual)
return dual.(blocklengths(nondual(a)))
end

function gradedunitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices)
# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
)
a_indices = getindex(nondual(a), indices)
return mortar([label_dual(b) for b in blocks(a_indices)])
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}})
return gradedunitrangedual_getindices_blocks(a, indices)
function blockedunitrange_getindices(
a::GradedUnitRangeDual,
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
)
v = mortar(map(b -> a[b], blocks(indices)))
# GradedOneTo appears in mortar
# flip v axis to preserve dual information
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
return flip_blockvector(v)
end

function blockedunitrange_getindices(
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
)
return gradedunitrangedual_getindices_blocks(a, indices)
# Without converting `indices` to `Vector`,
# mapping `indices` outputs a `BlockVector`
# which is harder to reason about.
vblocks = map(index -> a[index], Vector(indices))
# We pass `length.(blocks)` to `mortar` in order
# to pass block labels to the axes of the output,
# if they exist. This makes it so that
# `only(axes(a[indices])) isa `GradedUnitRange`
# if `a isa `GradedUnitRange`, for example.

v = mortar(vblocks, length.(vblocks))
# GradedOneTo appears in mortar
# flip v axis to preserve dual information
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)]))
return flip_blockvector(v)
end

function flip_blockvector(v::BlockVector)
# TODO way to create BlockArray with specified axis without relying on internal?
block_axes = flip.(axes(v))
flipped = BlockArrays._BlockArray(vec.(blocks(v)), block_axes)
ogauthe marked this conversation as resolved.
Show resolved Hide resolved
return flipped
end

Base.axes(a::GradedUnitRangeDual) = axes(nondual(a))
Expand Down
38 changes: 38 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# LabelledUnitRangeDual is obtained by slicing a GradedUnitRangeDual with a block

using ..LabelledNumbers: LabelledNumbers, label, labelled, unlabel

struct LabelledUnitRangeDual{T,NondualUnitRange<:AbstractUnitRange{T}} <:
AbstractUnitRange{T}
nondual_unitrange::NondualUnitRange
end

dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a)
nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange
dual(a::LabelledUnitRangeDual) = nondual(a)
flip(a::LabelledUnitRangeDual) = dual(flip(nondual(a)))
isdual(::LabelledUnitRangeDual) = true

LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a)))
LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a))

for f in [:first, :getindex, :last, :length, :step]
@eval Base.$f(a::LabelledUnitRangeDual, args...) =
labelled($f(unlabel(a), args...), label(a))
end

# fix ambiguities
Base.getindex(a::LabelledUnitRangeDual, i::Integer) = dual(nondual(a)[i])

function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRangeDual)
println(io, typeof(a))
return print(io, label(a), " => ", unlabel(a))
end

function Base.show(io::IO, a::LabelledUnitRangeDual)
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
end

function Base.AbstractUnitRange{T}(a::LabelledUnitRangeDual) where {T}
return AbstractUnitRange{T}(nondual(a))
end
36 changes: 32 additions & 4 deletions NDTensors/src/lib/GradedAxes/test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using NDTensors.GradedAxes:
AbstractGradedUnitRange,
GradedAxes,
GradedUnitRangeDual,
LabelledUnitRangeDual,
OneToOne,
blocklabels,
blockmergesortperm,
Expand All @@ -27,7 +28,8 @@ using NDTensors.GradedAxes:
gradedrange,
isdual,
nondual
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal
using NDTensors.LabelledNumbers:
LabelledInteger, LabelledUnitRange, label, labelled, labelled_isequal, unlabel
using Test: @test, @test_broken, @testset
struct U1
n::Int
Expand Down Expand Up @@ -58,6 +60,24 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
@test blockisequal(ad, a)
end

@testset "LabelledUnitRangeDual" begin
la = labelled(1:2, U1(1))
@test la isa LabelledUnitRange
@test label(la) == U1(1)
@test unlabel(la) == 1:2
@test la == 1:2
@test !isdual(la)

lad = dual(la)
@test lad isa LabelledUnitRangeDual
@test label(lad) == U1(-1)
@test unlabel(lad) == 1:2
@test lad == 1:2
@test isdual(lad)
@test nondual(lad) === la
@test dual(lad) === la
end

@testset "GradedUnitRangeDual" begin
for a in
[gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]]
Expand Down Expand Up @@ -124,13 +144,21 @@ end
@test blockmergesortperm(a) == [Block(1), Block(2)]
@test blockmergesortperm(ad) == [Block(1), Block(2)]

@test_broken isdual(ad[Block(1)])
@test_broken isdual(ad[Block(1)[1:1]])
@test isdual(ad[Block(1)])
@test isdual(ad[Block(1)[1:1]])
@test ad[Block(1)] isa LabelledUnitRangeDual
@test ad[Block(1)[1:1]] isa LabelledUnitRangeDual
@test label(ad[Block(2)]) == U1(-1)
@test label(ad[Block(2)[1:1]]) == U1(-1)

I = mortar([Block(2)[1:1]])
g = ad[I]
@test length(g) == 1
@test label(first(g)) == U1(-1)
@test_broken isdual(g[Block(1)])
@test isdual(g[Block(1)])

@test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)])
@test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
end
end

Expand Down
Loading