Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] BlockSparseArray contract, QR, and Hermitian eigendecomposition #1247

Merged
merged 21 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("permuteddimsarray.jl")
include("blockarrays.jl")
include("sparsearray.jl")
include("blocksparsearray.jl")
include("allocate_output.jl")
include("subarray.jl")
include("broadcast.jl")
include("fusedims.jl")
Expand Down
105 changes: 105 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/allocate_output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#############################################################################
# Generic
#############################################################################

# Infer the type `f` returns when called with arguments of the given types.
# TODO: Is this good to use here?
# Seems best for `Number` subtypes, maybe restrict to that here.
output_type(f, argtypes::Type...) = Base.promote_op(f, argtypes...)

#############################################################################
# AbstractArray
#############################################################################

# Related to:
# https://github.com/JuliaLang/julia/issues/18161
# https://github.com/JuliaLang/julia/issues/25107
# https://github.com/JuliaLang/julia/issues/11557
abstract type AbstractArrayStructure{ElType,Axes} end

# Structure describing a dense array to be allocated: its element type and axes.
# The explicit keyword constructor below is the backwards-compatible spelling
# of what `@kwdef` would generate.
# TODO: Add a default for `eltype`.
struct ArrayStructure{ElType,Axes} <: AbstractArrayStructure{ElType,Axes}
  eltype::ElType
  axes::Axes
end
ArrayStructure(; eltype, axes) = ArrayStructure(eltype, axes)

# Element type of the output of `map_nonzeros` applied to arrays of the given
# types: promote the element types through the mapped function `fmap`.
function output_eltype(::typeof(map_nonzeros), fmap, arraytypes::Type{<:AbstractArray}...)
  return output_type(fmap, map(eltype, arraytypes)...)
end

# Value-level overload: forward to the type-level method.
# TODO: Compute based on runtime information?
function output_eltype(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  return output_eltype(f, fmap, map(typeof, as)...)
end

# Axes of the output of `map_nonzeros`: all inputs must share the same axes,
# which become the output axes.
# TODO: Make this more sophisticated, BlockSparseArrays
# may have different block shapes.
function output_axes(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  # Guard before `first(as)` so an empty call errors clearly instead of with
  # a confusing `BoundsError`.
  isempty(as) && throw(ArgumentError("`output_axes` requires at least one array."))
  # Explicit check instead of `@assert`, which can be disabled at higher
  # optimization levels and is not meant for input validation.
  allequal(axes.(as)) ||
    throw(DimensionMismatch("All input arrays must have equal axes."))
  return axes(first(as))
end

# Defaults to `ArrayStructure`.
# Maybe define a `default_output_structure`?
# NOTE: removed stray web-scrape artifact lines ("marked this conversation as
# resolved") that were embedded mid-function and broke the syntax.
function output_structure(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  return ArrayStructure(;
    eltype=output_eltype(f, fmap, as...), axes=output_axes(f, fmap, as...)
  )
end

# Defaults to `ArrayStructure`.
# Maybe define a `default_output_type`?
# Dense fallback: the output container type for `map_nonzeros` over generic
# `AbstractArray`s is `Array` (inputs are ignored).
output_type(::typeof(map_nonzeros), fmap, as::AbstractArray...) = Array

# Allocate an array with uninitialized/undefined memory
# according the array type and structure (for example the
# size or axes).
function allocate(arraytype::Type{<:AbstractArray}, structure)
  # TODO: Use `set_eltype`.
  elt = structure.eltype
  return arraytype{elt}(undef, structure.axes)
end

# Allocate as above, then overwrite every element with zero.
function allocate_zeros(arraytype::Type{<:AbstractArray}, structure)
  out = allocate(arraytype, structure)
  # Assumes `arraytype` is mutable; `false` acts as a strong zero and is
  # converted to `zero(eltype(out))` on assignment.
  # TODO: Use `zeros!!` or `zerovector!!` from VectorInterface.jl?
  map!(Returns(false), out, out)
  return out
end

# Allocate a zero-initialized output array for `map_nonzeros`, with container
# type and structure (eltype, axes, ...) inferred from the inputs.
function allocate_output(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
  structure = output_structure(f, fmap, as...)
  return allocate_zeros(output_type(f, fmap, as...), structure)
end

#############################################################################
# SparseArray
#############################################################################

# TODO: Maybe store nonzero locations?
# Structure describing a sparse array to be allocated: element type, axes, and
# the `zero` function used for unstored entries. The explicit keyword
# constructor is the backwards-compatible spelling of what `@kwdef` generates.
# TODO: Add a default for `eltype` and `zero`.
struct SparseArrayStructure{ElType,Axes,Zero} <: AbstractArrayStructure{ElType,Axes}
  eltype::ElType
  axes::Axes
  zero::Zero
end
SparseArrayStructure(; eltype, axes, zero) = SparseArrayStructure(eltype, axes, zero)

# Sparse overload: construct the `SparseArray` directly from its axes and
# `zero` function (no undef buffer to fill).
function allocate(arraytype::Type{<:SparseArray}, structure::SparseArrayStructure)
  # TODO: Use `set_eltype`.
  elt = structure.eltype
  return arraytype{elt}(structure.axes, structure.zero)
end

# Sparse overload: also carries the `zero` function of the inputs.
# NOTE: removed stray web-scrape artifact lines ("marked this conversation as
# resolved") that were embedded mid-function and broke the syntax; wrapped the
# over-long line.
function output_structure(f::typeof(map_nonzeros), fmap, as::SparseArrayLike...)
  return SparseArrayStructure(;
    eltype=output_eltype(f, fmap, as...),
    axes=output_axes(f, fmap, as...),
    zero=output_zero(f, fmap, as...),
  )
end

# Sparse inputs map to a `SparseArray` output container (inputs are ignored).
output_type(::typeof(map_nonzeros), fmap, as::SparseArrayLike...) = SparseArray

# `zero` function for the output: taken from the first input.
# TODO: Check they are all the same, update for now `axes`?
function output_zero(f::typeof(map_nonzeros), fmap, as::SparseArrayLike...)
  a1 = first(as)
  return a1.zero
end
24 changes: 22 additions & 2 deletions NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,20 @@ struct BlockZero{Axes}
end

# NOTE(review): the scraped diff interleaved the removed-side signature
# (`arraytype::Type{<:AbstractArray{T,N}}) where {T,N}`) with the added side;
# this is the reconstructed added-side method. Confirm against the PR.
# Build a zero block of `arraytype` for the block containing CartesianIndex `I`.
function (f::BlockZero)(
  arraytype::Type{<:AbstractArray{<:Any,N}}, I::CartesianIndex{N}
) where {N}
  # TODO: Make sure this works for sparse or block sparse blocks, immutable
  # blocks, diagonal blocks, etc.!
  return fill!(arraytype(undef, block_size(f.axes, Block(Tuple(I)))), false)
end

# Fallback so that `SparseArray` with scalar elements works.
# NOTE: removed stray web-scrape artifact lines ("marked this conversation as
# resolved") that were embedded mid-function and broke the syntax.
function (f::BlockZero)(blocktype::Type{<:Number}, I::CartesianIndex)
  return zero(blocktype)
end

# Fallback to Array if it is abstract
function (f::BlockZero)(
arraytype::Type{AbstractArray{T,N}}, I::CartesianIndex{N}
Expand Down Expand Up @@ -234,6 +243,12 @@ function Base.getindex(block_arr::BlockSparseArray{T,N}, i::Vararg{Integer,N}) w
return v
end

# Fixes ambiguity error.
# TODO: Is this needed?
# Scalar indexing of a 0-dimensional block sparse array: fetch the single
# 0-dimensional block and unwrap its only element.
function Base.getindex(block_arr::BlockSparseArray{<:Any,0})
  zero_dim_block = blocks(block_arr)[CartesianIndex()]
  return zero_dim_block[]
end

function Base.permutedims!(a_src::BlockSparseArray, a_dest::BlockSparseArray, perm)
copyto!(a_src, PermutedDimsArray(a_dest, perm))
return a_src
Expand Down Expand Up @@ -270,3 +285,8 @@ function Base.map(f, as::BlockSparseArray...)
@assert allequal(axes.(as))
return BlockSparseArray(map(f, blocks.(as)...), axes(first(as)))
end

# In-place elementwise map over block sparse arrays.
# Mutates the blocks of `a_dest` and, per the `Base.map!` convention, returns
# `a_dest` itself — the original wrapped the mutated blocks in a brand-new
# `BlockSparseArray`, allocating a fresh wrapper on every call.
function Base.map!(f, a_dest::BlockSparseArray, as::BlockSparseArray...)
  # Explicit check instead of `@assert` (which can be compiled out).
  allequal(axes.((a_dest, as...))) ||
    throw(DimensionMismatch("Arguments to `map!` must have equal axes."))
  map!(f, blocks(a_dest), blocks.(as)...)
  return a_dest
end
15 changes: 15 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ function _broadcast(f, as::AbstractArray...)
return map(f, as...)
end

# In-place broadcast kernel: only zero-preserving functions are supported,
# since they allow mapping over just the stored (nonzero) entries.
function _broadcast!(f, a_dest::AbstractArray, as::AbstractArray...)
  preserves_zeros(f, as...) ||
    error("Broadcasting functions that don't preserve zeros isn't supported yet.")
  # TODO: Use `map_nonzeros!` here?
  return map!(f, a_dest, as...)
end

# Predicate: is `x` a scalar `Number`? Resolved purely by dispatch.
isnumber(::Number) = true
isnumber(::Any) = false

Expand All @@ -50,3 +58,10 @@ function Base.copy(bc::Broadcasted{<:BlockSparseStyle})
f, args = flatten_numbers(bcf.f, bcf.args)
return _broadcast(f, args...)
end

# Materialize a broadcast expression into `a_dest` in place.
# Flattens nested broadcasts and folds scalar (`Number`) arguments into the
# function, then delegates to `_broadcast!`.
function Base.copyto!(a_dest::BlockSparseArray, bc::Broadcasted{<:BlockSparseStyle})
  bcf = flatten(bc)
  f, args = flatten_numbers(bcf.f, bcf.args)
  return _broadcast!(f, a_dest, args...)
  # Removed unreachable `return error("Not implemented")` that followed the
  # return above (dead code left over from development).
end
61 changes: 59 additions & 2 deletions NDTensors/src/BlockSparseArrays/src/sparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
zero::Zero
end

# Default `zero` function for unstored entries: ignores the index and returns
# `zero` of the requested element type.
function default_zero()
  return (elt, I) -> zero(elt)
end

# Convenience constructors for `SparseArray{T}`.
# NOTE: removed stray web-scrape artifact lines ("marked this conversation as
# resolved") that were embedded between the definitions; wrapped over-long
# one-liners to the file's line-length convention.
function SparseArray{T}(size::Tuple{Vararg{Integer}}, zero=default_zero()) where {T}
  return SparseArray(Dictionary{CartesianIndex{length(size)},T}(), size, zero)
end
SparseArray{T}(size::Integer...) where {T} = SparseArray{T}(size)

# NOTE(review): the axes-based constructors keep only `length.(axes)`, so any
# offset-axes information is dropped — assumes one-based axes; confirm callers
# never pass offset axes.
function SparseArray{T}(axes::Tuple{Vararg{AbstractUnitRange}}, zero=default_zero()) where {T}
  return SparseArray{T}(length.(axes), zero)
end
SparseArray{T}(axes::AbstractUnitRange...) where {T} = SparseArray{T}(length.(axes))

# The dimensions are stored directly in the `dims` field.
function Base.size(a::SparseArray)
  return a.dims
end

function Base.setindex!(a::SparseArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
Expand Down Expand Up @@ -79,9 +87,58 @@ function map_nonzeros!(f, a_dest::AbstractArray, as::SparseArrayLike...)
return a_dest
end

## NOTE(review): the commented-out block below duplicates definitions that now
## live in allocate_output.jl; it is a candidate for deletion.
## function output_type(f, args::Type...)
## # TODO: Is this good to use here?
## # Seems best for `Number` subtypes, maybe restrict to that here.
## return Base.promote_op(f, args...)
## end
##
## function output_eltype(::typeof(map_nonzeros), fmap, as::Type{<:AbstractArray}...)
## return output_type(fmap, eltype.(as)...)
## end
##
## function output_eltype(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
## return output_eltype(f, fmap, typeof.(as)...)
## end
##
## function output_structure(f::typeof(map_nonzeros), fmap, as::SparseArray...)
## end
##
## function output_structure(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
## return ArrayStructure(; eltype=output_eltype(f, fmap, as...), axes=output_axes(f, fmap, as...))
## end
##
## function output_type(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
## return error("Not implemented")
## end
##
## function output_type(f::typeof(map_nonzeros), fmap, as::SparseArrayLike...)
## return SparseArray
## end
##
## # Allocate an array with uninitialized/undefined memory
## # according the array type and structure (for example the
## # size or axes).
## function allocate(arraytype::Type{<:AbstractArray}, structure)
## return arraytype(undef, structure)
## end
##
## function allocate_zeros(arraytype::Type{<:AbstractArray}, structure)
## a = allocate(arraytype, structure)
## # TODO: Use `zeros!!` or `zerovector!!` from VectorInterface.jl.
## zeros!(a)
## return a
## end
##
## function allocate_output(f::typeof(map_nonzeros), fmap, as::AbstractArray...)
## return allocate_zeros(output_type(f, fmap, as...), output_structure(f, fmap, as...))
## end
##
# Out-of-place elementwise map over sparse-like arrays.
# Allocates the output via `allocate_output` (which infers container type,
# eltype, axes, and `zero` from the inputs), then fills it in place.
# Removed the commented-out leftovers (`## @assert allequal(axes.(as))`,
# `# a_dest = zero(first(as))`) that cluttered the body; behavior is unchanged.
function map_nonzeros(f, as::SparseArrayLike...)
  a_dest = allocate_output(map_nonzeros, f, as...)
  map!(f, a_dest, as...)
  return a_dest
end
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/BlockSparseArrays/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
25 changes: 24 additions & 1 deletion NDTensors/src/BlockSparseArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using BlockArrays: BlockArrays, blockedrange, blocksize
using BlockArrays: BlockArrays, BlockRange, blocksize
using NDTensors: contract
using NDTensors.BlockSparseArrays:
BlockSparseArrays, BlockSparseArray, gradedrange, nonzero_blockkeys, fusedims
using ITensors: QN
Expand Down Expand Up @@ -56,4 +57,26 @@ using ITensors: QN
# need to fix that.
@test_broken length(nonzero_blockkeys(B_fused)) == 2
end
@testset "contract" begin
  # Fill only the blocks whose block-index sum is even with random values,
  # leaving the others structurally zero.
  function randn_even_blocks!(a)
    for b in BlockRange(a)
      if iseven(sum(b.n))
        a[b] = randn(eltype(a), size(@view(a[b])))
      end
    end
  end

  d1, d2, d3, d4 = [2, 3], [3, 4], [4, 5], [5, 6]
  elt = Float64
  a1 = BlockSparseArray{elt}(d1, d2, d3)
  randn_even_blocks!(a1)
  a2 = BlockSparseArray{elt}(d2, d4)
  randn_even_blocks!(a2)
  a_dest, labels_dest = contract(a1, (1, -1, 2), a2, (-1, 3))
  # Was `@show`, which only printed the comparison; `@test` actually asserts it.
  @test labels_dest == (1, 2, 3)

  # TODO: Output `labels_dest` as well.
  # Compare the block sparse contraction against a dense reference.
  a_dest_dense = contract(Array(a1), (1, -1, 2), Array(a2), (-1, 3))
  # Was `@show`; `≈` is correct here since the values are floating point.
  @test a_dest ≈ a_dest_dense
end
end
6 changes: 4 additions & 2 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@ include("arraystorage/diagonalarray/tensor/contract.jl")
# BlockSparseArray storage
include("arraystorage/blocksparsearray/storage/unwrap.jl")
include("arraystorage/blocksparsearray/storage/contract.jl")
include("arraystorage/blocksparsearray/storage/qr.jl")
include("arraystorage/blocksparsearray/storage/eigen.jl")
include("arraystorage/blocksparsearray/storage/svd.jl")

## TODO: Delete once it is rewritten for array storage types.
## include("arraystorage/blocksparsearray/tensor/combiner/contract_uncombine.jl")
include("arraystorage/blocksparsearray/tensor/contract.jl")

# Combiner storage
include("arraystorage/combiner/storage/promote_rule.jl")
Expand Down
Loading
Loading