Skip to content

Commit

Permalink
[GradedAxes] Introduce GradedUnitRangeDual (#1531)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe authored Nov 5, 2024
1 parent a5c3cf5 commit c49d7f2
Show file tree
Hide file tree
Showing 15 changed files with 585 additions and 321 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using NDTensors.GradedAxes:
GradedAxes, GradedOneTo, UnitRangeDual, blocklabels, dual, gradedrange
GradedAxes,
GradedOneTo,
GradedUnitRange,
GradedUnitRangeDual,
blocklabels,
dual,
gradedrange,
isdual
using NDTensors.LabelledNumbers: label
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: fusedims, splitdims
using LinearAlgebra: adjoint
using Random: randn!
function blockdiagonal!(f, a::AbstractArray)
for i in 1:minimum(blocksize(a))
Expand All @@ -31,15 +40,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
@test axes(a, 1) isa GradedOneTo
@test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo

for b in (a + a, 2 * a)
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
@test nstored(b) == 32
@test block_nstored(b) == 2
# TODO: Have to investigate why this fails
# on Julia v1.6, or drop support for v1.6.
for i in 1:ndims(a)
@test axes(b, i) isa GradedOneTo
end
Expand Down Expand Up @@ -103,16 +112,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))
end

@testset "dual axes" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
a = BlockSparseArray{elt}(ax...)
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(ax[dim]) === typeof(axes(a, dim))
@test isdual(ax[dim]) == isdual(axes(a, dim))
end
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
Expand All @@ -130,41 +140,149 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a[I] == a_dense[I]
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))

@test isdual(axes(a', 1)) isdual(axes(a, 2))
@test isdual(axes(a', 2)) isdual(axes(a, 1))
@test isnothing(show(devnull, MIME("text/plain"), a))

# Check preserving dual in tensor algebra.
for b in (a + a, 2 * a, 3 * a - a)
@test Array(b) 2 * Array(a)
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(axes(b, dim)) === typeof(axes(b, dim))
@test isdual(axes(b, dim)) == isdual(axes(a, dim))
end
end

@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
end

@testset "GradedOneTo" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedOneTo
@test axes(a[:, :], i) isa GradedOneTo
end

I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test !isdual(axes(a[I, I], 1))
end

@testset "GradedUnitRange" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRange
@test axes(a[:, :], i) isa GradedUnitRange
end

I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test axes(a[I, :], 1) isa GradedOneTo
@test axes(a[I, :], 2) isa GradedUnitRange

@test a[:, I] isa AbstractBlockArray
@test axes(a[:, I], 2) isa GradedOneTo
@test axes(a[:, I], 1) isa GradedUnitRange
@test size(a[I, I]) == (1, 1)
@test !isdual(axes(a[I, I], 1))
end

# Test case when all axes are dual.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
@testset "dual GradedOneTo" begin
r = gradedrange([U1(-1) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for ax in axes(b)
@test ax isa UnitRangeDual
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
@test axes(a[:, :], i) isa GradedUnitRangeDual
end
I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
end

@testset "dual GradedUnitRange" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
@test axes(a[:, :], i) isa GradedUnitRangeDual
end

I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
end

# Test case when all axes are dual
# from taking the adjoint.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
@testset "dual BlockedUnitRange" begin # self dual
r = blockedrange([2, 2])
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
@test a[:, :] isa BlockSparseArray
for i in 1:2
@test axes(b, i) isa BlockedOneTo
@test axes(a[:, :], i) isa BlockedOneTo
end

I = [Block(1)[1:1]]
@test a[I, :] isa BlockSparseArray
@test a[:, I] isa BlockSparseArray
@test size(a[I, I]) == (1, 1)
@test !isdual(axes(a[I, I], 1))
end

# Test case when all axes are dual from taking the adjoint.
for r in (
gradedrange([U1(0) => 2, U1(1) => 2]),
gradedrange([U1(0) => 2, U1(1) => 2])[begin:end],
)
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
Expand All @@ -173,8 +291,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)'
for ax in axes(b)
@test ax isa UnitRangeDual
@test ax isa typeof(dual(r))
end

I = [Block(1)[1:1]]
@test size(b[I, :]) == (1, 4)
@test size(b[:, I]) == (4, 1)
@test size(b[I, I]) == (1, 1)
end
end
@testset "Matrix multiplication" begin
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using BlockArrays:
BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock
AbstractBlockedUnitRange,
BlockArrays,
Block,
BlockIndexRange,
BlockedVector,
blocklength,
blocksize,
viewblock

# This splits `BlockIndexRange{N}` into
# `NTuple{N,BlockIndexRange{1}}`.
Expand Down Expand Up @@ -191,7 +198,9 @@ function to_blockindexrange(
# work right now.
return blocks(a.blocks)[Int(I)]
end
function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1})
function to_blockindexrange(
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
)
@assert I in only(blockaxes(a.indices))
return I
end
Expand Down
21 changes: 19 additions & 2 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@ using BlockArrays:
blocksizes,
mortar
using Compat: @compat
using LinearAlgebra: mul!
using LinearAlgebra: Adjoint, mul!
using NDTensors.BlockSparseArrays:
@view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view!
@view!,
BlockSparseArray,
BlockView,
block_nstored,
block_reshape,
block_stored_indices,
view!
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset
Expand All @@ -44,6 +50,17 @@ include("TestBlockSparseArraysUtils.jl")
a[Block(2, 2)] = randn(elt, 3, 3)
@test a[2:4, 4] == Array(a)[2:4, 4]
@test_broken a[4, 2:4]

@test a[Block(1), :] isa BlockSparseArray{elt}
@test adjoint(a) isa Adjoint{elt,<:BlockSparseArray}
@test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray}
# could also be directly a BlockSparseArray

a = BlockSparseArray{elt}([1], [1, 1])
a[1, 2] = 1
@test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector
ah = adjoint(a)
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
end
@testset "Basics" begin
a = BlockSparseArray{elt}([2, 3], [2, 3])
Expand Down
3 changes: 2 additions & 1 deletion NDTensors/src/lib/GradedAxes/src/GradedAxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GradedAxes
include("blockedunitrange.jl")
include("gradedunitrange.jl")
include("dual.jl")
include("unitrangedual.jl")
include("gradedunitrangedual.jl")
include("onetoone.jl")
include("fusion.jl")
end
2 changes: 1 addition & 1 deletion NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ end
# Slice `a` by `I`, returning a:
# `BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}`
# with the `BlockIndex{1}` corresponding to each value of `I`.
function to_blockindices(a::BlockedOneTo{<:Integer}, I::UnitRange{<:Integer})
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:Integer})
return mortar(
map(blocks(blockedunitrange_getindices(a, I))) do r
bi_first = findblockindex(a, first(r))
Expand Down
7 changes: 6 additions & 1 deletion NDTensors/src/lib/GradedAxes/src/dual.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
function dual end
# default behavior: self-dual
dual(r::AbstractUnitRange) = r
nondual(r::AbstractUnitRange) = r
isdual(::AbstractUnitRange) = false

using NDTensors.LabelledNumbers:
LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel

dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
label_dual(x) = label_dual(LabelledStyle(x), x)
label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))
Expand Down
15 changes: 4 additions & 11 deletions NDTensors/src/lib/GradedAxes/src/fusion.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
using BlockArrays: AbstractBlockedUnitRange, blocklengths

# Represents the range `1:1` or `Base.OneTo(1)`.
struct OneToOne{T} <: AbstractUnitRange{T} end
OneToOne() = OneToOne{Bool}()
Base.first(a::OneToOne) = one(eltype(a))
Base.last(a::OneToOne) = one(eltype(a))
BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool}

# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl
# https://en.wikipedia.org/wiki/Tensor_product
# https://github.com/KeitaNakamura/Tensorial.jl
Expand All @@ -20,7 +13,7 @@ function tensor_product(
end

flip_dual(r::AbstractUnitRange) = r
flip_dual(r::UnitRangeDual) = flip(r)
flip_dual(r::GradedUnitRangeDual) = flip(r)
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
return tensor_product(flip_dual(a1), flip_dual(a2))
end
Expand Down Expand Up @@ -67,7 +60,7 @@ function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRan
return blockedrange(new_blocklengths)
end

# convention: sort UnitRangeDual according to nondual blocks
# convention: sort GradedUnitRangeDual according to nondual blocks
function blocksortperm(a::AbstractUnitRange)
return Block.(sortperm(blocklabels(nondual(a))))
end
Expand Down Expand Up @@ -102,7 +95,7 @@ function blockmergesort(g::AbstractGradedUnitRange)
return gradedrange(new_blocklengths)
end

blockmergesort(g::UnitRangeDual) = flip(blockmergesort(flip(g)))
blockmergesort(g::GradedUnitRangeDual) = flip(blockmergesort(flip(g)))
blockmergesort(g::AbstractUnitRange) = g

# fusion_product produces a sorted, non-dual GradedUnitRange
Expand All @@ -111,7 +104,7 @@ function fusion_product(g1, g2)
end

fusion_product(g::AbstractUnitRange) = blockmergesort(g)
fusion_product(g::UnitRangeDual) = fusion_product(flip(g))
fusion_product(g::GradedUnitRangeDual) = fusion_product(flip(g))

# recursive fusion_product. Simpler than reduce + fix type stability issues with reduce
function fusion_product(g1, g2, g3...)
Expand Down
Loading

0 comments on commit c49d7f2

Please sign in to comment.