Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] BlockSparseArray contract, QR, and Hermitian eigendecomposition #1247

Merged
merged 21 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
Expand Down Expand Up @@ -50,6 +51,7 @@ LinearAlgebra = "1.6"
Random = "1.6"
Requires = "1.1"
SimpleTraits = "0.9.4"
SparseArrays = "1.6"
SplitApplyCombine = "1.2.2"
StaticArrays = "0.12, 1.0"
Strided = "2"
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using BlockArrays
using Compat
using Dictionaries
using SplitApplyCombine
using LinearAlgebra: Hermitian, Transpose

using BlockArrays: block

Expand All @@ -13,9 +14,13 @@ include("base.jl")
include("axes.jl")
include("abstractarray.jl")
include("permuteddimsarray.jl")
include("hermitian.jl")
include("transpose.jl")
include("blockarrays.jl")
# TODO: Split off into `NDSparseArrays` module.
include("sparsearray.jl")
include("blocksparsearray.jl")
include("allocate_output.jl")
include("subarray.jl")
include("broadcast.jl")
include("fusedims.jl")
Expand Down
120 changes: 120 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/allocate_output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#############################################################################
# Generic
#############################################################################

# Compute the scalar type that `f` returns when applied to values of the
# given scalar types.
function output_type(f, args::Type...)
  # TODO: Is this good to use here?
  # Seems best for `Number` subtypes, maybe restrict to that here.
  # NOTE: evaluates `f` on zeros of each type rather than relying on
  # inference (`Base.promote_op(f, args...)` is the inference-based
  # alternative).
  return typeof(f(map(zero, args)...))
end

# Compute the output array type of applying `f` elementwise to arrays of
# the given types. All inputs must have the same dimensionality.
function output_type(f::Function, as::Type{<:AbstractArray}...)
  @assert allequal(map(ndims, as))
  elt = output_type(f, map(eltype, as)...)
  # TODO: Generalize this to GPU arrays!
  return Array{elt,ndims(first(as))}
end

#############################################################################
# AbstractArray
#############################################################################

# Related to:
# https://github.com/JuliaLang/julia/issues/18161
# https://github.com/JuliaLang/julia/issues/25107
# https://github.com/JuliaLang/julia/issues/11557
#
# Abstract description of an output array (element type plus axes),
# used to allocate outputs without eagerly constructing them.
abstract type AbstractArrayStructure{ElType,Axes} end

# Structure describing a dense array: its element type and its axes.
# NOTE: the fields hold runtime values (e.g. `Float64`, `(1:2, 1:3)`);
# the type parameters are the types of those values.
# TODO: Make this backwards compatible.
# TODO: Add a default for `eltype`.
@kwdef struct ArrayStructure{ElType,Axes} <: AbstractArrayStructure{ElType,Axes}
  eltype::ElType
  axes::Axes
end

# Determine the output element type of `map_nonzeros` from the element
# types of the input array types.
function output_eltype(::typeof(map_nonzeros), fmap, as::Type{<:AbstractArray}...)
  return output_type(fmap, map(eltype, as)...)
end

# Runtime-value version: forwards to the type-based method above.
function output_eltype(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  # TODO: Compute based on runtime information?
  return output_eltype(f, fmap, map(typeof, as)...)
end

# Output axes of `map_nonzeros`: all inputs must share the same axes,
# which are used for the output.
function output_axes(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  # TODO: Make this more sophisticated, BlockSparseArrays
  # may have different block shapes.
  first_axes = axes(first(as))
  @assert all(a -> axes(a) == first_axes, as)
  return first_axes
end

# Defaults to `ArrayStructure`.
# Maybe define a `default_output_structure`?
function output_structure(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  elt = output_eltype(f, fmap, as...)
  ax = output_axes(f, fmap, as...)
  return ArrayStructure(; eltype=elt, axes=ax)
end

# Generic (dense) inputs default to an `Array` output.
# Maybe define a `default_output_type`?
output_type(f::typeof(map_nonzeros), fmap, as::AbstractArray...) = Array

# Allocate an array with uninitialized/undefined memory
# according to the array type and structure (for example the
# size or axes).
function allocate(arraytype::Type{<:AbstractArray}, structure)
  # TODO: Use `set_eltype`.
  elt = structure.eltype
  return arraytype{elt}(undef, structure.axes)
end

# Allocate an array as above and zero out every entry.
function allocate_zeros(arraytype::Type{<:AbstractArray}, structure)
  out = allocate(arraytype, structure)
  # Assumes the allocated array is mutable.
  # TODO: Use `zeros!!` or `zerovector!!` from VectorInterface.jl?
  map!(Returns(false), out, out)
  return out
end

# Allocate a zero-initialized output array for `map_nonzeros` with the
# appropriate type and structure for the given inputs.
function allocate_output(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  arraytype = output_type(f, fmap, as...)
  structure = output_structure(f, fmap, as...)
  return allocate_zeros(arraytype, structure)
end

#############################################################################
# SparseArray
#############################################################################

# Structure describing a `SparseArray`: element type, axes, and the
# function used to produce values for unstored entries.
# TODO: Maybe store nonzero locations?
# TODO: Make this backwards compatible.
# TODO: Add a default for `eltype` and `zero`.
@kwdef struct SparseArrayStructure{ElType,Axes,Zero} <: AbstractArrayStructure{ElType,Axes}
  eltype::ElType
  axes::Axes
  zero::Zero
end

# Allocate an empty `SparseArray` from its structure (axes plus the
# generator for unstored values).
function allocate(arraytype::Type{<:SparseArray}, structure::SparseArrayStructure)
  # TODO: Use `set_eltype`.
  elt = structure.eltype
  return arraytype{elt}(structure.axes, structure.zero)
end

# Sparse-like inputs produce a `SparseArrayStructure`, which additionally
# records the generator for unstored values.
function output_structure(f::typeof(map_nonzeros), fmap, as::SparseArrayLike...)
  fields = (;
    eltype=output_eltype(f, fmap, as...),
    axes=output_axes(f, fmap, as...),
    zero=output_zero(f, fmap, as...),
  )
  return SparseArrayStructure(; fields...)
end

# Sparse-like inputs map to a `SparseArray` output.
output_type(f::typeof(map_nonzeros), fmap, as::SparseArrayLike...) = SparseArray

# Use the `zero` generator of the first input for the output.
function output_zero(f::typeof(map_nonzeros), fmap, as::SparseArrayLike...)
  # TODO: Check they are all the same, update for now `axes`?
  return first(as).zero
end
48 changes: 45 additions & 3 deletions NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ blocktype(a::BlockSparseArray{<:Any,<:Any,A}) where {A} = A
# TODO: Use `SetParameters`.
set_ndims(::Type{<:Array{T}}, n) where {T} = Array{T,n}

function nonzero_blockkeys(a::BlockSparseArray)
# TODO: Move to `AbstractArray` file.
# Collect the `Block` labels of all stored (nonzero) blocks of `a`.
function nonzero_blockkeys(a::AbstractArray)
  return [Block(Tuple(k)) for k in nonzero_keys(blocks(a))]
end

Expand Down Expand Up @@ -60,11 +61,18 @@ struct BlockZero{Axes}
end

# Construct a zero block of type `arraytype` for the block at Cartesian
# position `I`, sized according to the blocked axes stored in `f.axes`.
function (f::BlockZero)(
  arraytype::Type{<:AbstractArray{<:Any,N}}, I::CartesianIndex{N}
) where {N}
  # TODO: Make sure this works for sparse or block sparse blocks, immutable
  # blocks, diagonal blocks, etc.!
  blocksz = block_size(f.axes, Block(Tuple(I)))
  block = arraytype(undef, blocksz)
  return fill!(block, false)
end

# Fallback so that `SparseArray` with scalar elements works.
(f::BlockZero)(blocktype::Type{<:Number}, I::CartesianIndex) = zero(blocktype)

# Fallback to Array if it is abstract
function (f::BlockZero)(
arraytype::Type{AbstractArray{T,N}}, I::CartesianIndex{N}
Expand Down Expand Up @@ -98,6 +106,22 @@ function BlockSparseArray{T,N,B}(
return BlockSparseArray{T,N,B}(undef, blocks, axes)
end

## struct BlockSparseArray{
## T,N,A<:AbstractArray{T,N},Blocks<:SparseArray{A,N},Axes<:NTuple{N,AbstractUnitRange{Int}}
## } <: AbstractBlockArray{T,N}
## blocks::Blocks
## axes::Axes
## end

# Construct a `BlockSparseArray` from a `SparseArray` of blocks and the
# blocked axes of the overall array.
# NOTE(review): assumes the struct has exactly the type parameters
# `{T,N,A,Blocks,Axes}` (scalar type, rank, block type, block storage,
# axes), matching the commented-out struct sketch above — confirm against
# the actual `BlockSparseArray` definition.
function BlockSparseArray(a::SparseArray, axes::Tuple{Vararg{AbstractUnitRange}})
  # `eltype(a)` is the block type; its own `eltype` is the scalar type.
  A = eltype(a)
  T = eltype(A)
  N = ndims(a)
  Blocks = typeof(a)
  Axes = typeof(axes)
  return BlockSparseArray{T,N,A,Blocks,Axes}(a, axes)
end

function BlockSparseArray(
blockdata::Dictionary{<:Block{N}}, axes::Tuple{Vararg{AbstractUnitRange{Int},N}}
) where {N}
Expand Down Expand Up @@ -234,6 +258,12 @@ function Base.getindex(block_arr::BlockSparseArray{T,N}, i::Vararg{Integer,N}) w
return v
end

# Fixes ambiguity error.
# TODO: Is this needed?
# Zero-dimensional case: extract the scalar from the single stored block.
function Base.getindex(block_arr::BlockSparseArray{<:Any,0})
  zerodim_block = blocks(block_arr)[CartesianIndex()]
  return zerodim_block[]
end

function Base.permutedims!(a_src::BlockSparseArray, a_dest::BlockSparseArray, perm)
copyto!(a_src, PermutedDimsArray(a_dest, perm))
return a_src
Expand All @@ -252,6 +282,13 @@ function BlockArrays.blocks(
return PermutedDimsArray(blocks(parent(a)), perm(a))
end

# TODO: Make `PermutedBlockSparseArray`.
# The block storage of a `Hermitian`-wrapped `BlockSparseArray` is the
# `Hermitian` view of the parent's block storage.
function BlockArrays.blocks(a::Hermitian{<:Any,<:BlockSparseArray})
  return Hermitian(blocks(parent(a)))
end

# TODO: Make `PermutedBlockSparseArray`.
function Base.zero(a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:BlockSparseArray})
return BlockSparseArray(zero(blocks(a)), axes(a))
Expand All @@ -270,3 +307,8 @@ function Base.map(f, as::BlockSparseArray...)
@assert allequal(axes.(as))
return BlockSparseArray(map(f, blocks.(as)...), axes(first(as)))
end

# In-place elementwise map over block sparse arrays.
# Mutates the block storage of `a_dest` and returns `a_dest` itself,
# following the `Base.map!` convention (the previous version wrapped the
# result in a brand-new `BlockSparseArray` instead of returning the
# destination that was passed in).
function Base.map!(f, a_dest::BlockSparseArray, as::BlockSparseArray...)
  @assert allequal(axes.((a_dest, as...)))
  map!(f, blocks(a_dest), blocks.(as)...)
  return a_dest
end
50 changes: 50 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,20 @@ function _broadcast(f, as::AbstractArray...)
return map(f, as...)
end

# In-place broadcast implementation: only zero-preserving functions are
# supported, since unstored entries are implicit zeros.
function _broadcast!(f, a_dest::AbstractArray, as::AbstractArray...)
  preserves_zeros(f, as...) ||
    error("Broadcasting functions that don't preserve zeros isn't supported yet.")
  # TODO: Use `map_nonzeros!` here?
  return map!(f, a_dest, as...)
end

# Predicate: is the argument a scalar `Number`?
isnumber(::Number) = true
isnumber(::Any) = false

function flatten_numbers(f, args)
# TODO: Is there a simpler way to implement this?
# This doesn't play well with `Base.promote_op`.
function flattened_f(flattened_args...)
j = 0
unflattened_args = ntuple(length(args)) do i
Expand All @@ -50,3 +59,44 @@ function Base.copy(bc::Broadcasted{<:BlockSparseStyle})
f, args = flatten_numbers(bcf.f, bcf.args)
return _broadcast(f, args...)
end

# In-place broadcast into a `BlockSparseArray`: flatten the nested
# broadcast tree, absorb scalar arguments into the function, then
# dispatch to the in-place map-style implementation.
function Base.copyto!(a_dest::BlockSparseArray, bc::Broadcasted{<:BlockSparseStyle})
  flat_bc = flatten(bc)
  f, arrays = flatten_numbers(flat_bc.f, flat_bc.args)
  return _broadcast!(f, a_dest, arrays...)
end

## # Special algebra cases
## struct LeftMul{C}
## c::C
## end
## (f::LeftMul)(x) = f.c * x
##
## struct RightMul{C}
## c::C
## end
## (f::RightMul)(x) = x * f.c
##
## # 2 .* a
## function Base.copy(bc::Broadcasted{<:BlockSparseStyle,<:Any,typeof(*),<:Tuple{<:Number,<:AbstractArray}})
## # TODO: Use `map_nonzeros`.
## return map(LeftMul(bc.args[1]), bc.args[2])
## end
##
## # a .* 2
## function Base.copy(bc::Broadcasted{<:BlockSparseStyle,<:Any,typeof(*),<:Tuple{<:AbstractArray,<:Number}})
## # TODO: Use `map_nonzeros`.
## return map(RightMul(bc.args[2]), bc.args[1])
## end
##
## # a ./ 2
## function Base.copy(bc::Broadcasted{<:BlockSparseStyle,<:Any,typeof(/),<:Tuple{<:AbstractArray,<:Number}})
## # TODO: Use `map_nonzeros`.
## return map(RightMul(inv(bc.args[2])), bc.args[1])
## end
##
## # a .+ b
## function Base.copy(bc::Broadcasted{<:BlockSparseStyle,<:Any,<:Union{typeof(+),typeof(-)},<:Tuple{<:AbstractArray,<:AbstractArray}})
## # TODO: Use `map_nonzeros`.
## return map(+, bc.args...)
## end
3 changes: 3 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/hermitian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Forward the stored keys of the parent matrix.
# TODO: This needs to be done more carefully by
# grabbing upper or lower triangles of the parent matrix:
# a `Hermitian` mirrors one triangle onto the other, so a stored key in
# one triangle implies a (possibly unstored) nonzero at its mirror.
nonzero_keys(a::Hermitian) = nonzero_keys(parent(a))
20 changes: 18 additions & 2 deletions NDTensors/src/BlockSparseArrays/src/sparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@ struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
zero::Zero
end

# Default generator for unstored entries: ignores the index and returns
# the zero of the requested element type.
default_zero() = (elt, I) -> zero(elt)

# Construct an empty `SparseArray{T}` with the given dimensions.
# `zero` produces values for unstored entries, called as `zero(eltype, I)`.
function SparseArray{T}(size::Tuple{Vararg{Integer}}, zero=default_zero()) where {T}
  return SparseArray(Dictionary{CartesianIndex{length(size)},T}(), size, zero)
end
SparseArray{T}(size::Integer...) where {T} = SparseArray{T}(size)

# Axes-based constructors; only the lengths of the axes are kept, so
# offset axes are not supported here.
function SparseArray{T}(
  axes::Tuple{Vararg{AbstractUnitRange}}, zero=default_zero()
) where {T}
  return SparseArray{T}(length.(axes), zero)
end
SparseArray{T}(axes::AbstractUnitRange...) where {T} = SparseArray{T}(length.(axes))

# Dimensions are stored explicitly since only nonzero entries are stored.
Base.size(a::SparseArray) = a.dims

function Base.setindex!(a::SparseArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
Expand Down Expand Up @@ -80,8 +94,10 @@ function map_nonzeros!(f, a_dest::AbstractArray, as::SparseArrayLike...)
end

# Out-of-place elementwise map over sparse-like arrays.
# The output is freshly allocated with an element type determined from
# `f` and the inputs (which may promote the element type, unlike
# `zero(first(as))` would).
function map_nonzeros(f, as::SparseArrayLike...)
  a_dest = allocate_output(map_nonzeros, f, as...)
  map!(f, a_dest, as...)
  return a_dest
end
Expand Down
6 changes: 6 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/transpose.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Stored keys of a lazy `Transpose`: each parent key with its two
# indices swapped, produced lazily as a generator.
function nonzero_keys(a::Transpose)
  return (
    CartesianIndex(reverse(Tuple(parent_index))) for
    parent_index in nonzero_keys(parent(a))
  )
end
2 changes: 2 additions & 0 deletions NDTensors/src/BlockSparseArrays/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
25 changes: 24 additions & 1 deletion NDTensors/src/BlockSparseArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using BlockArrays: BlockArrays, blockedrange, blocksize
using BlockArrays: BlockArrays, BlockRange, blocksize
using NDTensors: contract
using NDTensors.BlockSparseArrays:
BlockSparseArrays, BlockSparseArray, gradedrange, nonzero_blockkeys, fusedims
using ITensors: QN
Expand Down Expand Up @@ -56,4 +57,26 @@ using ITensors: QN
# need to fix that.
@test_broken length(nonzero_blockkeys(B_fused)) == 2
end
@testset "contract" begin
  # Fill only the blocks whose block position sums to an even number,
  # so the arrays are genuinely block sparse.
  function randn_even_blocks!(a)
    for b in BlockRange(a)
      if iseven(sum(b.n))
        a[b] = randn(eltype(a), size(@view(a[b])))
      end
    end
  end

  d1, d2, d3, d4 = [2, 3], [3, 4], [4, 5], [5, 6]
  elt = Float64
  a1 = BlockSparseArray{elt}(d1, d2, d3)
  randn_even_blocks!(a1)
  a2 = BlockSparseArray{elt}(d2, d4)
  randn_even_blocks!(a2)
  a_dest, labels_dest = contract(a1, (1, -1, 2), a2, (-1, 3))
  # Use `@test` so failures are reported; `@show` only printed the result
  # and could never fail the test suite.
  @test labels_dest == (1, 2, 3)

  # Compare against the dense contraction as a reference.
  # TODO: Output `labels_dest` as well.
  a_dest_dense = contract(Array(a1), (1, -1, 2), Array(a2), (-1, 3))
  @test a_dest ≈ a_dest_dense
end
end
Loading
Loading