Skip to content

Commit

Permalink
Merge branch 'main' into BlockSparseArrays_gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Nov 7, 2024
2 parents 7e663ed + 4299ab4 commit ad358af
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 25 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.56"
version = "0.3.57"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using Test: @test, @testset
using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
Expand Down Expand Up @@ -217,10 +217,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
@test isdual(axes(a[I, :], 1))
@test isdual(axes(a[:, I], 2))
@test isdual(axes(a[I, I], 1))
@test isdual(axes(a[I, I], 2))
end

@testset "dual GradedUnitRange" begin
Expand All @@ -243,10 +243,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
@test isdual(axes(a[I, :], 1))
@test isdual(axes(a[:, I], 2))
@test isdual(axes(a[I, I], 1))
@test isdual(axes(a[I, I], 2))
end

@testset "dual BlockedUnitRange" begin # self dual
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/lib/GradedAxes/src/GradedAxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GradedAxes
include("blockedunitrange.jl")
include("gradedunitrange.jl")
include("dual.jl")
include("labelledunitrangedual.jl")
include("gradedunitrangedual.jl")
include("onetoone.jl")
include("fusion.jl")
Expand Down
5 changes: 3 additions & 2 deletions NDTensors/src/lib/GradedAxes/src/dual.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# default behavior: self-dual
dual(r::AbstractUnitRange) = r
# default behavior: any object is self-dual
dual(x) = x
nondual(r::AbstractUnitRange) = r
isdual(::AbstractUnitRange) = false

Expand All @@ -11,4 +11,5 @@ label_dual(x) = label_dual(LabelledStyle(x), x)
label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))

flip(a::AbstractUnitRange) = dual(label_dual(a))
flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g))))
1 change: 1 addition & 0 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ end
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
struct NoLabel end
blocklabels(r::AbstractUnitRange) = Fill(NoLabel(), blocklength(r))
blocklabels(la::LabelledUnitRange) = [label(la)]

function LabelledNumbers.labelled_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
Expand Down
53 changes: 44 additions & 9 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange)
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockIndexRange)
return dual(nondual(a)[indices])
end

# fix ambiguity
Expand All @@ -49,20 +53,51 @@ function BlockArrays.blocklengths(a::GradedUnitRangeDual)
return dual.(blocklengths(nondual(a)))
end

function gradedunitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices)
# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
)
a_indices = getindex(nondual(a), indices)
return mortar([label_dual(b) for b in blocks(a_indices)])
v = mortar(dual.(blocks(a_indices)))
# flip v to stay consistent with other cases where axes(v) are used
return flip_blockvector(v)
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}})
return gradedunitrangedual_getindices_blocks(a, indices)
function blockedunitrange_getindices(
a::GradedUnitRangeDual,
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
)
v = mortar(map(b -> a[b], blocks(indices)))
# GradedOneTo appears in mortar
# flip v axis to preserve dual information
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
return flip_blockvector(v)
end

function blockedunitrange_getindices(
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
)
return gradedunitrangedual_getindices_blocks(a, indices)
# Without converting `indices` to `Vector`,
# mapping `indices` outputs a `BlockVector`
# which is harder to reason about.
vblocks = map(index -> a[index], Vector(indices))
# We pass `length.(blocks)` to `mortar` in order
# to pass block labels to the axes of the output,
# if they exist. This makes it so that
# `only(axes(a[indices])) isa `GradedUnitRange`
# if `a isa `GradedUnitRange`, for example.

v = mortar(vblocks, length.(vblocks))
# GradedOneTo appears in mortar
# flip v axis to preserve dual information
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)]))
return flip_blockvector(v)
end

function flip_blockvector(v::BlockVector)
block_axes = flip.(axes(v))
flipped = mortar(vec.(blocks(v)), block_axes)
return flipped
end

Base.axes(a::GradedUnitRangeDual) = axes(nondual(a))
Expand Down
49 changes: 49 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# LabelledUnitRangeDual is obtained by slicing a GradedUnitRangeDual with a block

using ..LabelledNumbers: LabelledNumbers, label, labelled, unlabel

struct LabelledUnitRangeDual{T,NondualUnitRange<:AbstractUnitRange{T}} <:
AbstractUnitRange{T}
nondual_unitrange::NondualUnitRange
end

dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a)
nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange
dual(a::LabelledUnitRangeDual) = nondual(a)
label_dual(::IsLabelled, a::LabelledUnitRangeDual) = dual(label_dual(nondual(a)))
isdual(::LabelledUnitRangeDual) = true
blocklabels(la::LabelledUnitRangeDual) = [label(la)]

LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a)))
LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a))
LabelledNumbers.LabelledStyle(::LabelledUnitRangeDual) = IsLabelled()

for f in [:first, :getindex, :last, :length, :step]
@eval Base.$f(a::LabelledUnitRangeDual, args...) =
labelled($f(unlabel(a), args...), label(a))
end

# fix ambiguities
Base.getindex(a::LabelledUnitRangeDual, i::Integer) = dual(nondual(a)[i])
function Base.getindex(a::LabelledUnitRangeDual, indices::AbstractUnitRange{<:Integer})
return dual(nondual(a)[indices])
end

function Base.iterate(a::LabelledUnitRangeDual, i)
i == last(a) && return nothing
next = convert(eltype(a), labelled(i + step(a), label(a)))
return (next, next)
end

function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRangeDual)
println(io, typeof(a))
return print(io, label(a), " => ", unlabel(a))
end

function Base.show(io::IO, a::LabelledUnitRangeDual)
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
end

function Base.AbstractUnitRange{T}(a::LabelledUnitRangeDual) where {T}
return AbstractUnitRange{T}(nondual(a))
end
104 changes: 100 additions & 4 deletions NDTensors/src/lib/GradedAxes/test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using NDTensors.GradedAxes:
AbstractGradedUnitRange,
GradedAxes,
GradedUnitRangeDual,
LabelledUnitRangeDual,
OneToOne,
blocklabels,
blockmergesortperm,
Expand All @@ -27,7 +28,8 @@ using NDTensors.GradedAxes:
gradedrange,
isdual,
nondual
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal
using NDTensors.LabelledNumbers:
LabelledInteger, LabelledUnitRange, label, label_type, labelled, labelled_isequal, unlabel
using Test: @test, @test_broken, @testset
struct U1
n::Int
Expand Down Expand Up @@ -58,6 +60,92 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
@test blockisequal(ad, a)
end

@testset "LabelledUnitRangeDual" begin
la = labelled(1:2, U1(1))
@test la isa LabelledUnitRange
@test label(la) == U1(1)
@test blocklabels(la) == [U1(1)]
@test unlabel(la) == 1:2
@test la == 1:2
@test !isdual(la)
@test labelled_isequal(la, la)
@test space_isequal(la, la)
@test label_type(la) == U1

@test iterate(la) == (1, 1)
@test iterate(la) == (1, 1)
@test iterate(la, 1) == (2, 2)
@test isnothing(iterate(la, 2))

lad = dual(la)
@test lad isa LabelledUnitRangeDual
@test label(lad) == U1(-1)
@test blocklabels(lad) == [U1(-1)]
@test unlabel(lad) == 1:2
@test lad == 1:2
@test labelled_isequal(lad, lad)
@test space_isequal(lad, lad)
@test !labelled_isequal(la, lad)
@test !space_isequal(la, lad)
@test isdual(lad)
@test nondual(lad) === la
@test dual(lad) === la
@test label_type(lad) == U1

@test iterate(lad) == (1, 1)
@test iterate(lad) == (1, 1)
@test iterate(lad, 1) == (2, 2)
@test isnothing(iterate(lad, 2))

lad2 = lad[1:1]
@test lad2 isa LabelledUnitRangeDual
@test label(lad2) == U1(-1)
@test unlabel(lad2) == 1:1

laf = flip(la)
@test laf isa LabelledUnitRangeDual
@test label(laf) == U1(1)
@test unlabel(laf) == 1:2
@test labelled_isequal(la, laf)
@test !space_isequal(la, laf)

ladf = flip(dual(la))
@test ladf isa LabelledUnitRange
@test label(ladf) == U1(-1)
@test unlabel(ladf) == 1:2

lafd = dual(flip(la))
@test lafd isa LabelledUnitRange
@test label(lafd) == U1(-1)
@test unlabel(lafd) == 1:2

# check default behavior for objects without dual
la = labelled(1:2, 'x')
lad = dual(la)
@test lad isa LabelledUnitRangeDual
@test label(lad) == 'x'
@test blocklabels(lad) == ['x']
@test unlabel(lad) == 1:2
@test lad == 1:2
@test labelled_isequal(lad, lad)
@test space_isequal(lad, lad)
@test labelled_isequal(la, lad)
@test !space_isequal(la, lad)
@test isdual(lad)
@test nondual(lad) === la
@test dual(lad) === la

laf = flip(la)
@test laf isa LabelledUnitRangeDual
@test label(laf) == 'x'
@test unlabel(laf) == 1:2

ladf = flip(lad)
@test ladf isa LabelledUnitRange
@test label(ladf) == 'x'
@test unlabel(ladf) == 1:2
end

@testset "GradedUnitRangeDual" begin
for a in
[gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]]
Expand Down Expand Up @@ -124,13 +212,21 @@ end
@test blockmergesortperm(a) == [Block(1), Block(2)]
@test blockmergesortperm(ad) == [Block(1), Block(2)]

@test_broken isdual(ad[Block(1)])
@test_broken isdual(ad[Block(1)[1:1]])
@test isdual(ad[Block(1)])
@test isdual(ad[Block(1)[1:1]])
@test ad[Block(1)] isa LabelledUnitRangeDual
@test ad[Block(1)[1:1]] isa LabelledUnitRangeDual
@test label(ad[Block(2)]) == U1(-1)
@test label(ad[Block(2)[1:1]]) == U1(-1)

I = mortar([Block(2)[1:1]])
g = ad[I]
@test length(g) == 1
@test label(first(g)) == U1(-1)
@test_broken isdual(g[Block(1)])
@test isdual(g[Block(1)])

@test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)])
@test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
end
end

Expand Down
9 changes: 9 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,12 @@ function Base.iterate(a::LabelledUnitRange, i)
next = convert(eltype(a), labelled(i + step(a), label(a)))
return (next, next)
end

function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRange)
println(io, typeof(a))
return print(io, label(a), " => ", unlabel(a))
end

function Base.show(io::IO, a::LabelledUnitRange)
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
end

0 comments on commit ad358af

Please sign in to comment.