Merge branch 'main' into non-abelian_fusion
ogauthe committed Jun 14, 2024
2 parents 19df9ea + 87ec605 commit a86e67e
Showing 54 changed files with 647 additions and 1,164 deletions.
4 changes: 2 additions & 2 deletions NDTensors/Project.toml
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.24"
version = "0.3.29"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -56,7 +56,7 @@ AMDGPU = "0.9"
Accessors = "0.1.33"
Adapt = "3.7, 4"
ArrayLayouts = "1.4"
BlockArrays = "0.16"
BlockArrays = "1"
CUDA = "5"
Compat = "4.9"
cuTENSOR = "2"
@@ -1,14 +1,22 @@
module BlockSparseArraysGradedAxesExt
using BlockArrays: AbstractBlockVector, Block, BlockedUnitRange, blocks
using BlockArrays:
AbstractBlockVector,
AbstractBlockedUnitRange,
Block,
BlockIndexRange,
blockedrange,
blocks
using ..BlockSparseArrays:
BlockSparseArrays,
AbstractBlockSparseArray,
AbstractBlockSparseMatrix,
BlockSparseArray,
BlockSparseMatrix,
BlockSparseVector,
block_merge
using ...GradedAxes:
GradedUnitRange,
GradedAxes,
AbstractGradedUnitRange,
OneToOne,
blockmergesortperm,
blocksortperm,
@@ -23,11 +31,13 @@ using ...TensorAlgebra:
# TODO: Make a `ReduceWhile` library.
include("reducewhile.jl")

TensorAlgebra.FusionStyle(::GradedUnitRange) = SectorFusion()
TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()

# TODO: Need to implement this! Will require implementing
# `block_merge(a::AbstractUnitRange, blockmerger::BlockedUnitRange)`.
function BlockSparseArrays.block_merge(a::GradedUnitRange, blockmerger::BlockedUnitRange)
function BlockSparseArrays.block_merge(
a::AbstractGradedUnitRange, blockmerger::AbstractBlockedUnitRange
)
return a
end
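A minimal editorial sketch (not part of the commit) of the placeholder behavior above: because the method simply returns `a`, repeated sector labels are not merged yet. The labels and import paths below are illustrative assumptions.

using BlockArrays: blockedrange
using NDTensors.GradedAxes: gradedrange
using NDTensors.BlockSparseArrays: block_merge  # unexported; assumed reachable at this path

r = gradedrange(["A" => 2, "B" => 2, "A" => 3])
# The stub returns the graded axis unchanged, so the two "A" blocks stay separate:
block_merge(r, blockedrange([2, 2, 3])) === r  # true with the placeholder definition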

@@ -75,6 +85,44 @@ function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
return dual.(reverse(axes(a')))
end

# TODO: Delete this definition in favor of the one in
# GradedAxes once https://github.com/JuliaArrays/BlockArrays.jl/pull/405 is merged.
# TODO: Make a special definition for `BlockedVector{<:Block{1}}` in order
# to merge blocks.
function GradedAxes.blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
)
# Without converting `indices` to `Vector`,
# mapping `indices` outputs a `BlockVector`
# which is harder to reason about.
blocks = map(index -> a[index], Vector(indices))
# We pass `length.(blocks)` to `mortar` in order
# to pass block labels to the axes of the output,
# if they exist. This makes it so that
# `only(axes(a[indices])) isa GradedUnitRange`
# if `a isa GradedUnitRange`, for example.
# TODO: Remove `unlabel` once `BlockArray` axes
# type is generalized in BlockArrays.jl.
# TODO: Support using `BlockSparseVector`, need
# to make more `BlockSparseArray` constructors.
return BlockSparseArray(blocks, (blockedrange(length.(blocks)),))
end
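A rough editorial usage sketch of the method above, assuming NDTensors.GradedAxes and this extension are loaded: indexing a blocked axis with a vector of blocks yields a block-sparse vector whose axis keeps the per-block lengths (and the sector labels when the input axis is graded).

using BlockArrays: Block, blockedrange
using NDTensors.GradedAxes: blockedunitrange_getindices  # unexported; assumed reachable here

r = blockedrange([2, 3, 4])
v = blockedunitrange_getindices(r, [Block(1), Block(3)])
# Per the definition above, `v` is a BlockSparseArray built from [r[Block(1)], r[Block(3)]],
# i.e. a blocked vector with block lengths [2, 4].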

# This definition is only needed since calls like
# `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
# return a `BlockSparseVector` instead of a `BlockVector`
# due to limitations in the `BlockArray` type not allowing
# axes with non-Int element types.
# TODO: Remove this once that issue is fixed,
# see https://github.com/JuliaArrays/BlockArrays.jl/pull/405.
using BlockArrays: BlockRange
using NDTensors.LabelledNumbers: label
function GradedAxes.blocklabels(a::BlockSparseVector)
return map(BlockRange(a)) do block
return label(blocks(a)[Int(block)])
end
end
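A hedged editorial sketch of how this interacts with the graded-axis slicing described above; plain string labels stand in for the sector types used elsewhere in this commit.

using BlockArrays: Block
using NDTensors.GradedAxes: blocklabels, gradedrange

r = gradedrange(["even" => 2, "odd" => 3])
v = r[[Block(1), Block(2)]]  # a BlockSparseVector, per the comment above
blocklabels(v)               # expected to recover the labels ["even", "odd"]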

# This is a temporary fix for `show` being broken for BlockSparseArrays
# with mixed dual and non-dual axes. This shouldn't be needed once
# GradedAxes is rewritten using BlockArrays v1.
@@ -1,10 +1,10 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using BlockArrays: Block, blockedrange, blocksize
using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using NDTensors.GradedAxes:
GradedAxes, GradedUnitRange, UnitRangeDual, blocklabels, dual, gradedrange
GradedAxes, GradedOneTo, UnitRangeDual, blocklabels, dual, gradedrange
using NDTensors.LabelledNumbers: label
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: fusedims, splitdims
@@ -35,15 +35,34 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
for b in (a + a, 2 * a)
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
@test nstored(b) == 32
@test block_nstored(b) == 2
# TODO: Have to investigate why this fails
# on Julia v1.6, or drop support for v1.6.
for i in 1:ndims(a)
@test axes(b, i) isa GradedUnitRange
@test axes(b, i) isa GradedOneTo
end
@test label(axes(b, 1)[Block(1)]) == U1(0)
@test label(axes(b, 1)[Block(2)]) == U1(1)
@test Array(b) isa Array{elt}
@test Array(b) == b
@test 2 * Array(a) == b
end

# Test mixing graded axes and dense axes
# in addition/broadcasting.
for b in (a + Array(a), Array(a) + a)
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
# TODO: Fix this for `BlockedArray`.
@test_broken nstored(b) == 256
# TODO: Fix this for `BlockedArray`.
@test_broken block_nstored(b) == 16
for i in 1:ndims(a)
@test axes(b, i) isa BlockedOneTo{Int}
end
@test Array(a) isa Array{elt}
@test Array(a) == a
@test 2 * Array(a) == b
@@ -55,7 +74,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test nstored(b) == 2
@test block_nstored(b) == 2
for i in 1:ndims(a)
@test axes(b, i) isa GradedUnitRange
@test axes(b, i) isa GradedOneTo
end
@test label(axes(b, 1)[Block(1)]) == U1(0)
@test label(axes(b, 1)[Block(2)]) == U1(1)
@@ -72,7 +91,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
# TODO: Once block merging is implemented, this should
# be the real test.
for ax in axes(m)
@test ax isa GradedUnitRange
@test ax isa GradedOneTo
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
@@ -95,37 +114,48 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
end
@testset "dual axes" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(dual(r), r)
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
# TODO: Define and use `isdual` here.
@test axes(a, 1) isa UnitRangeDual
@test axes(a, 2) isa GradedUnitRange
@test !(axes(a, 2) isa UnitRangeDual)
a_dense = Array(a)
@test eachindex(a) == CartesianIndices(size(a))
for I in eachindex(a)
@test a[I] == a_dense[I]
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test axes(a', 1) isa UnitRangeDual
@test axes(a', 2) isa GradedUnitRange
@test !(axes(a', 2) isa UnitRangeDual)
@test isnothing(show(devnull, MIME("text/plain"), a))

# Check preserving dual in tensor algebra.
for b in (a + a, 2 * a, 3 * a - a)
@test Array(b) ≈ 2 * Array(a)
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
a = BlockSparseArray{elt}(ax...)
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
# TODO: Define and use `isdual` here.
@test axes(b, 1) isa UnitRangeDual
@test axes(b, 2) isa GradedUnitRange
@test !(axes(b, 2) isa UnitRangeDual)
end
for dim in 1:ndims(a)
@test typeof(ax[dim]) === typeof(axes(a, dim))
end
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
@test @view(a[Block(1, 1)])[1, 2] == a[1, 2]
@test @view(a[Block(1, 1)])[2, 2] == a[2, 2]
@test @view(a[Block(2, 2)])[1, 1] == a[3, 3]
@test @view(a[Block(2, 2)])[2, 1] == a[4, 3]
@test @view(a[Block(2, 2)])[1, 2] == a[3, 4]
@test @view(a[Block(2, 2)])[2, 2] == a[4, 4]
@test @view(a[Block(1, 1)])[1:2, 1:2] == a[1:2, 1:2]
@test @view(a[Block(2, 2)])[1:2, 1:2] == a[3:4, 3:4]
a_dense = Array(a)
@test eachindex(a) == CartesianIndices(size(a))
for I in eachindex(a)
@test a[I] == a_dense[I]
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))
@test isnothing(show(devnull, MIME("text/plain"), a))

# Check preserving dual in tensor algebra.
for b in (a + a, 2 * a, 3 * a - a)
@test Array(b) ≈ 2 * Array(a)
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(axes(b, dim)) === typeof(axes(a, dim))
end
end

@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
end

# Test case when all axes are dual.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
@@ -1,12 +1,14 @@
module BlockSparseArraysTensorAlgebraExt
using BlockArrays: BlockedUnitRange
using BlockArrays: AbstractBlockedUnitRange
using ..BlockSparseArrays: AbstractBlockSparseArray, block_reshape
using ...GradedAxes: tensor_product
using ...TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

TensorAlgebra.:⊗(a1::BlockedUnitRange, a2::BlockedUnitRange) = tensor_product(a1, a2)
function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
return tensor_product(a1, a2)
end

TensorAlgebra.FusionStyle(::BlockedUnitRange) = BlockReshapeFusion()
TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
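A short editorial sketch of the operator defined above: fusing two blocked axes with `⊗` delegates to `GradedAxes.tensor_product`, so the total length of the result is the product of the input lengths; the exact block structure is up to `tensor_product` and is not assumed here.

using BlockArrays: blockedrange
using NDTensors.TensorAlgebra: ⊗  # assumption: `⊗` is owned by TensorAlgebra, as the method above implies

r1 = blockedrange([2, 3])  # length 5
r2 = blockedrange([2, 2])  # length 4
r12 = r1 ⊗ r2              # == tensor_product(r1, r2)
length(r12)                # expected: 20 == length(r1) * length(r2)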

function TensorAlgebra.fusedims(
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
@@ -19,6 +19,21 @@ using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
struct GenericBlockSlice{B,T<:Integer,I<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
block::B
indices::I
end
BlockArrays.Block(bs::GenericBlockSlice{<:Block}) = bs.block
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
@eval Base.$f(S::GenericBlockSlice) = Base.$f(S.indices)
end
Base.getindex(S::GenericBlockSlice, i::Integer) = getindex(S.indices, i)
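A minimal editorial sketch of the wrapper above: unlike `BlockArrays.BlockSlice`, the `indices` field may be any `AbstractUnitRange{<:Integer}` (for example a labelled range coming from a graded axis), and range queries are forwarded to it. A plain `UnitRange{Int}` is used here only to show the forwarding.

using BlockArrays: Block

s = GenericBlockSlice(Block(2), 3:5)  # block 2 of some axis, covering parent indices 3:5
Block(s)   # Block(2)
length(s)  # 3, forwarded to `s.indices`
s[1]       # 3, getindex is also forwarded to `s.indices`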

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
blocks::B
indices::I
@@ -175,6 +190,13 @@ function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
function blockrange(axis::AbstractUnitRange, r::GenericBlockSlice)
return blockrange(axis, r.block)
end

function blockrange(a::AbstractUnitRange, r::BlockIndices)
return blockrange(a, r.blocks)
end
@@ -313,3 +335,34 @@ function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks)
return cartesianindices(subaxes, block)
end
end

function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
return view!(a, Tuple(index)...)
end
function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N}
blocks(a)[Int.(index)...] = blocks(a)[Int.(index)...]
return blocks(a)[Int.(index)...]
end

function view!(a::AbstractArray{<:Any,N}, index::BlockIndexRange{N}) where {N}
# TODO: Is there a better code pattern for this?
indices = ntuple(N) do dim
return Tuple(Block(index))[dim][index.indices[dim]]
end
return view!(a, indices...)
end
function view!(a::AbstractArray{<:Any,N}, index::Vararg{BlockIndexRange{1},N}) where {N}
b = view!(a, Block.(index)...)
r = map(index -> only(index.indices), index)
return @view b[r...]
end

using MacroTools: @capture
using NDTensors.SparseArrayDOKs: is_getindex_expr
macro view!(expr)
if !is_getindex_expr(expr)
error("@view must be used with getindex syntax (as `@view! a[i,j,...]`)")
end
@capture(expr, array_[indices__])
return :(view!($(esc(array)), $(esc.(indices)...)))
end
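A hedged editorial usage sketch of `view!` and `@view!` added above; the constructor call is an assumption patterned on the tests earlier in this commit. Unlike `view`, `view!` first stores the addressed block if it is missing, so the returned view can be written to.

using BlockArrays: Block, blockedrange
using NDTensors.BlockSparseArrays: BlockSparseArray  # view!/@view! themselves are assumed unexported

a = BlockSparseArray{Float64}(blockedrange([2, 2]), blockedrange([2, 2]))
b = view!(a, Block(2, 2))  # stores block (2, 2) if absent, then returns a view of it
b .= 1.0                   # writes go straight into `a`
c = @view! a[Block(1, 1)]  # macro form, lowering to view!(a, Block(1, 1))
c[1, 1] = 2.0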
@@ -1,4 +1,4 @@
using BlockArrays: BlockedUnitRange, BlockSlice
using BlockArrays: AbstractBlockedUnitRange, BlockSlice
using Base.Broadcast: Broadcast

function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike})
@@ -12,7 +12,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
<:Tuple{BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)
@@ -25,8 +25,8 @@ function Broadcast.BroadcastStyle(
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{
BlockSlice{<:Any,<:BlockedUnitRange},
BlockSlice{<:Any,<:BlockedUnitRange},
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
Vararg{Any},
},
},
@@ -40,7 +40,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Any,BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
<:Tuple{Any,BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)