Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] GradedAxes library #1271

Merged
merged 4 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ for lib in [
:SparseArrayDOKs,
:DiagonalArrays,
:BlockSparseArrays,
:GradedAxes,
:NamedDimsArrays,
:SmallVectors,
:SortedSets,
Expand Down
5 changes: 4 additions & 1 deletion NDTensors/src/arraystorage/diagonalarray/storage/contract.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using .SparseArrayInterface: densearray
using .DiagonalArrays: DiagIndex, diaglength

# TODO: Move to a different file.
Unwrap.parenttype(::Type{<:DiagonalArray{<:Any,<:Any,P}}) where {P} = P

Expand Down Expand Up @@ -99,7 +102,7 @@ function contract!(
coffset += ii * custride[i]
end
c = zero(eltype(C))
for j in 1:DiagonalArrays.diaglength(A)
for j in 1:diaglength(A)
# With α == 0 && β == 1
C[cstart + j * c_cstride + coffset] +=
A[DiagIndex(j)] * B[bstart + j * b_cstride + boffset]
Expand Down
5 changes: 3 additions & 2 deletions NDTensors/src/dims.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .DiagonalArrays: DiagonalArrays

export dense, dims, dim, mindim, diaglength

# dim and dims are used in the Tensor interface, overload
Expand Down Expand Up @@ -26,7 +28,7 @@ mindim(inds::Tuple) = minimum(dims(inds))

mindim(::Tuple{}) = 1

diaglength(inds::Tuple) = mindim(inds)
DiagonalArrays.diaglength(inds::Tuple) = mindim(inds)

"""
dim_to_strides(ds)
Expand Down Expand Up @@ -94,4 +96,3 @@ dim(T::Tensor) = dim(inds(T))
dim(T::Tensor, i::Int) = dim(inds(T), i)
maxdim(T::Tensor) = maxdim(inds(T))
mindim(T::Tensor) = mindim(inds(T))
diaglength(T::Tensor) = mindim(T)
5 changes: 4 additions & 1 deletion NDTensors/src/lib/DiagonalArrays/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
A Julia `DiagonalArray` type.

````julia
using NDTensors.DiagonalArrays: DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, isdiagindex
using NDTensors.DiagonalArrays:
DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex
using Test

function main()
d = DiagonalMatrix([1.0, 2.0, 3.0])
@test eltype(d) == Float64
@test diaglength(d) == 3
@test size(d) == (3, 3)
@test d[1, 1] == 1
@test d[2, 2] == 2
Expand All @@ -17,6 +19,7 @@ function main()

d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5)
@test eltype(d) == Float64
@test diaglength(d) == 3
@test d[1, 1, 1] == 1
@test d[2, 2, 2] == 2
@test d[3, 3, 3] == 3
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/lib/DiagonalArrays/examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
# A Julia `DiagonalArray` type.

using NDTensors.DiagonalArrays:
DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, isdiagindex
DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex
using Test

function main()
d = DiagonalMatrix([1.0, 2.0, 3.0])
@test eltype(d) == Float64
@test diaglength(d) == 3
@test size(d) == (3, 3)
@test d[1, 1] == 1
@test d[2, 2] == 2
Expand All @@ -17,6 +18,7 @@ function main()

d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5)
@test eltype(d) == Float64
@test diaglength(d) == 3
@test d[1, 1, 1] == 1
@test d[2, 2, 2] == 2
@test d[3, 3, 3] == 3
Expand Down
9 changes: 7 additions & 2 deletions NDTensors/src/lib/DiagonalArrays/src/diaginterface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
using Compat: allequal

diaglength(a::AbstractArray{<:Any,0}) = 1

function diaglength(a::AbstractArray)
return minimum(size(a))
end

function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
@boundscheck checkbounds(a, I)
return allequal(Tuple(I))
Expand All @@ -16,8 +22,7 @@ function diagstride(a::AbstractArray)
end

function diagindices(a::AbstractArray)
diaglength = minimum(size(a))
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength), ndims(a)))]
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength(a)), ndims(a)))]
return 1:diagstride(a):maxdiag
end

Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/DiagonalArrays/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
14 changes: 11 additions & 3 deletions NDTensors/src/lib/DiagonalArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test
using NDTensors.DiagonalArrays

@eval module $(gensym())
using Test: @test, @testset
using NDTensors.DiagonalArrays: DiagonalArrays
@testset "Test NDTensors.DiagonalArrays" begin
@testset "README" begin
@test include(
Expand All @@ -9,4 +9,12 @@ using NDTensors.DiagonalArrays
),
) isa Any
end
@testset "Basics" begin
using NDTensors.DiagonalArrays: diaglength
a = fill(1.0, 2, 3)
@test diaglength(a) == 2
a = fill(1.0)
@test diaglength(a) == 1
end
end
end
6 changes: 6 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/GradedAxes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module GradedAxes
include("groupsortperm.jl")
include("tensor_product.jl")
include("abstractgradedunitrange.jl")
include("gradedunitrange.jl")
end
104 changes: 104 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using BlockArrays:
BlockArrays,
Block,
BlockRange,
BlockedUnitRange,
blockaxes,
blockedrange,
blockfirsts,
blocklasts,
blocklengths,
findblock
using Dictionaries: Dictionary

# Fuse two symmetry labels
fuse(l1, l2) = error("Not implemented")

abstract type AbstractGradedUnitRange{T,G} <: AbstractUnitRange{Int} end

BlockArrays.blockedrange(a::AbstractGradedUnitRange) = error("Not implemented")
sectors(a::AbstractGradedUnitRange) = error("Not implemented")
scale_factor(a::AbstractGradedUnitRange) = error("Not implemented")

# BlockArrays block axis interface
BlockArrays.blockaxes(a::AbstractGradedUnitRange) = blockaxes(blockedrange(a))
Base.getindex(a::AbstractGradedUnitRange, b::Block{1}) = blockedrange(a)[b]
BlockArrays.blockfirsts(a::AbstractGradedUnitRange) = blockfirsts(blockedrange(a))
BlockArrays.blocklasts(a::AbstractGradedUnitRange) = blocklasts(blockedrange(a))
function BlockArrays.findblock(a::AbstractGradedUnitRange, k::Integer)
return findblock(blockedrange(a), k)
end

# Base axis interface
Base.getindex(a::AbstractGradedUnitRange, I::Integer) = blockedrange(a)[I]
Base.first(a::AbstractGradedUnitRange) = first(blockedrange(a))
Base.last(a::AbstractGradedUnitRange) = last(blockedrange(a))
Base.length(a::AbstractGradedUnitRange) = length(blockedrange(a))
Base.step(a::AbstractGradedUnitRange) = step(blockedrange(a))
Base.unitrange(b::AbstractGradedUnitRange) = first(b):last(b)

sector(a::AbstractGradedUnitRange, b::Block{1}) = sectors(a)[only(b.n)]
sector(a::AbstractGradedUnitRange, I::Integer) = sector(a, findblock(a, I))

# Tensor product, no sorting
function tensor_product(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
a = tensor_product(blockedrange(a1), blockedrange(a2))
sectors_a = map(Iterators.product(sectors(a1), sectors(a2))) do (l1, l2)
return fuse(scale_factor(a1) * l1, scale_factor(a2) * l2)
end
return gradedrange(a, vec(sectors_a))
end

function Base.show(io::IO, mimetype::MIME"text/plain", a::AbstractGradedUnitRange)
show(io, mimetype, sectors(a))
println(io)
println(io, "Scale factor = ", scale_factor(a))
return show(io, mimetype, blockedrange(a))
end

function blockmerge(a::AbstractGradedUnitRange, grouped_perm::Vector{Vector{Int}})
merged_sectors = map(group -> sector(a, Block(first(group))), grouped_perm)
lengths = blocklengths(a)
merged_lengths = map(group -> sum(@view(lengths[group])), grouped_perm)
return gradedrange(merged_sectors, merged_lengths)
end

# Sort and merge by the grade of the blocks.
function blockmergesort(a::AbstractGradedUnitRange)
grouped_perm = blockmergesortperm(a)
return blockmerge(a, grouped_perm)
end

# Get the permutation for sorting, then group by common elements.
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
function blockmergesortperm(a::AbstractGradedUnitRange)
return groupsortperm(sectors(a))
end

function sub_axis(a::AbstractGradedUnitRange, blocks)
a_sub = sub_axis(blockedrange(a), blocks)
sectors_sub = map(b -> sector(a, b), Indices(blocks))
return AbstractGradedUnitRange(a_sub, sectors_sub)
end

function fuse(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
a = tensor_product(a1, a2)
return blockmergesort(a)
end

## TODO: Add this back.
## # Slicing
## ## using BlockArrays: BlockRange, _BlockedUnitRange
## Base.@propagate_inbounds function Base.getindex(
## b::AbstractGradedUnitRange, KR::BlockRange{1}
## )
## cs = blocklasts(b)
## isempty(KR) && return _BlockedUnitRange(1, cs[1:0])
## K, J = first(KR), last(KR)
## k, j = Integer(K), Integer(J)
## bax = blockaxes(b, 1)
## @boundscheck K in bax || throw(BlockBoundsError(b, K))
## @boundscheck J in bax || throw(BlockBoundsError(b, J))
## K == first(bax) && return _BlockedUnitRange(first(b), cs[k:j])
## return _BlockedUnitRange(cs[k - 1] + 1, cs[k:j])
## end
23 changes: 23 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using BlockArrays: BlockArrays, BlockedUnitRange, blockedrange

struct GradedUnitRange{T,G,S} <: AbstractGradedUnitRange{T,G}
blockedrange::BlockedUnitRange{T}
sectors::Vector{G}
scale_factor::S
end

BlockArrays.blockedrange(s::GradedUnitRange) = s.blockedrange
sectors(s::GradedUnitRange) = s.sectors
scale_factor(s::GradedUnitRange) = s.scale_factor

function gradedrange(sectors::Vector, blocklengths::Vector{Int}, scale_factor=1)
return GradedUnitRange(blockedrange(blocklengths), sectors, scale_factor)
end

function gradedrange(sectors_lengths::Vector{<:Pair{<:Any,Int}}, scale_factor=1)
return gradedrange(first.(sectors_lengths), last.(sectors_lengths), scale_factor)
end

function gradedrange(a::BlockedUnitRange, sectors::Vector, scale_factor=1)
return GradedUnitRange(a, sectors, scale_factor)
end
15 changes: 15 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/groupsortperm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using BlockArrays: BlockVector
using SplitApplyCombine: groupcount

function groupsorted(v)
return groupcount(identity, v)
end

# Get the permutation for sorting, then group by common elements.
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
function groupsortperm(v)
perm = sortperm(v)
v_sorted = @view v[perm]
group_lengths = groupsorted(v_sorted)
return blocks(BlockVector(perm, collect(group_lengths)))
end
14 changes: 14 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/tensor_product.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using BlockArrays: BlockedUnitRange, blocks

# https://en.wikipedia.org/wiki/Tensor_product
# https://github.com/KeitaNakamura/Tensorial.jl
tensor_product(a1, a2, a3, as...) = foldl(tensor_product, (a1, a2, a3, as...))
tensor_product(a1, a2) = error("Not implemented for $(typeof(a1)) and $(typeof(a2)).")

function tensor_product(a1::Base.OneTo, a2::Base.OneTo)
return Base.OneTo(length(a1) * length(a2))
end

function tensor_product(a1::BlockedUnitRange, a2::BlockedUnitRange)
return blockedrange(prod.(length, vec(collect(Iterators.product(blocks.((a1, a2))...)))))
end
2 changes: 2 additions & 0 deletions NDTensors/src/lib/GradedAxes/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
77 changes: 77 additions & 0 deletions NDTensors/src/lib/GradedAxes/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
@eval module $(gensym())
using NDTensors.BlockArrays: Block, blocklength, blocklengths, findblock
using NDTensors.GradedAxes:
GradedAxes,
blockmerge,
blockmergesortperm,
fuse,
gradedrange,
sector,
sectors,
tensor_product
using Test: @test

struct U1
dim::Int
end
Base.isless(l1::U1, l2::U1) = isless(l1.dim, l2.dim)
Base.:*(c::Int, l::U1) = U1(c * l.dim)
GradedAxes.fuse(l1::U1, l2::U1) = U1(l1.dim + l2.dim)

a = gradedrange([U1(0), U1(1)], [2, 3])
@test a isa GradedAxes.GradedUnitRange
@test a == gradedrange([U1(0) => 2, U1(1) => 3])
@test length(a) == 5
@test a == 1:5
@test a[Block(1)] == 1:2
@test a[Block(2)] == 3:5
@test blocklength(a) == 2 # Number of sectors
@test blocklengths(a) == [2, 3]
# TODO: Maybe rename to `labels`, `label`.
@test sectors(a) == [U1(0), U1(1)]
@test sector(a, Block(1)) == U1(0)
@test sector(a, Block(2)) == U1(1)
@test findblock(a, 1) == Block(1)
@test findblock(a, 2) == Block(1)
@test findblock(a, 3) == Block(2)
@test findblock(a, 4) == Block(2)
@test findblock(a, 5) == Block(2)
@test sector(a, 1) == U1(0)
@test sector(a, 2) == U1(0)
@test sector(a, 3) == U1(1)
@test sector(a, 4) == U1(1)
@test sector(a, 5) == U1(1)

# Naive tensor product, no sorting and merging
a2 = tensor_product(a, a)
@test a2 isa GradedAxes.GradedUnitRange
@test a2 == gradedrange([U1(0) => 4, U1(1) => 6, U1(1) => 6, U1(2) => 9])
@test length(a2) == 25
@test a2 == 1:25
@test blocklength(a2) == 4
@test blocklengths(a2) == [4, 6, 6, 9]
@test sectors(a2) == [U1(0), U1(1), U1(1), U1(2)]
@test sector(a2, Block(1)) == U1(0)
@test sector(a2, Block(2)) == U1(1)
@test sector(a2, Block(3)) == U1(1)
@test sector(a2, Block(4)) == U1(2)

# Fusion tensor product, with sorting and merging
a2 = fuse(a, a)
@test a2 isa GradedAxes.GradedUnitRange
@test a2 == gradedrange([U1(0) => 4, U1(1) => 12, U1(2) => 9])
@test length(a2) == 25
@test a2 == 1:25
@test blocklength(a2) == 3
@test blocklengths(a2) == [4, 12, 9]
@test sectors(a2) == [U1(0), U1(1), U1(2)]
@test sector(a2, Block(1)) == U1(0)
@test sector(a2, Block(2)) == U1(1)
@test sector(a2, Block(3)) == U1(2)

# The partitioned permutation needed to sort
# and merge an unsorted graded space
perm_a = blockmergesortperm(tensor_product(a, a))
@test perm_a == [[1], [2, 3], [4]]
@test blockmerge(tensor_product(a, a), perm_a) == fuse(a, a)
end
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module SparseArrayInterface
include("densearray.jl")
include("interface.jl")
include("interface_optional.jl")
include("indexing.jl")
Expand Down
Loading
Loading