Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Get more Array storage functionality working #1222

Merged
merged 19 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,25 @@ include("empty/adapt.jl")

#####################################
# Array Tensor (experimental)
# TODO: Move to `Experimental` module.
# TODO: Move to `Experimental` module?
#
include("arraytensor/arraytensor.jl")
include("arraytensor/array.jl")
include("arraytensor/blocksparsearray.jl")
include("arraystorage/arraystorage/storage/arraystorage.jl")
include("arraystorage/arraystorage/storage/conj.jl")
include("arraystorage/arraystorage/storage/permutedims.jl")
include("arraystorage/arraystorage/storage/contract.jl")
include("arraystorage/arraystorage/storage/combiner.jl")

include("arraystorage/arraystorage/tensor/arraystorage.jl")
include("arraystorage/arraystorage/tensor/indexing.jl")
include("arraystorage/arraystorage/tensor/permutedims.jl")
include("arraystorage/arraystorage/tensor/mul.jl")
include("arraystorage/arraystorage/tensor/contract.jl")
include("arraystorage/arraystorage/tensor/qr.jl")
include("arraystorage/arraystorage/tensor/svd.jl")
include("arraystorage/arraystorage/tensor/combiner.jl")

# BlockSparseArray storage
include("arraystorage/blocksparsearray/storage/contract.jl")

#####################################
# Deprecations
Expand Down
30 changes: 30 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/storage/arraystorage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Used for dispatch to distinguish from Tensors wrapping TensorStorage.
# Remove once TensorStorage is removed.
const ArrayStorage{T,N} = Union{
Array{T,N},
ReshapedArray{T,N},
SubArray{T,N},
PermutedDimsArray{T,N},
StridedView{T,N},
BlockSparseArray{T,N},
}

const MatrixStorage{T} = Union{
ArrayStorage{T,2},
Transpose{T},
Adjoint{T},
Symmetric{T},
Hermitian{T},
UpperTriangular{T},
LowerTriangular{T},
UnitUpperTriangular{T},
UnitLowerTriangular{T},
Diagonal{T},
}

const MatrixOrArrayStorage{T} = Union{MatrixStorage{T},ArrayStorage{T}}

# TODO: Delete once `Dense` is removed.
function to_arraystorage(x::DenseTensor)
return tensor(reshape(data(x), size(x)), inds(x))
end
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
promote_rule(::Type{<:Combiner}, arraytype::Type{<:MatrixOrArrayStorage}) = arraytype
2 changes: 2 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/storage/conj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
conj(as::AliasStyle, A::AbstractArray) = conj(A)
conj(as::AllowAlias, A::Array{<:Real}) = A
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# Combiner
promote_rule(::Type{<:Combiner}, arraytype::Type{<:MatrixOrArrayStorage}) = arraytype

# Generic AbstractArray code
function contract(
array1::MatrixOrArrayStorage,
Expand Down Expand Up @@ -57,12 +54,3 @@ function contract!(
_contract!(arrayR, array1, array2, props)
return arrayR
end

function permutedims!(
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function
)
output_array = permutedims!!(
leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f
)
return output_array
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
function permutedims!(
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function
)
output_array = permutedims!!(
leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f
)
return output_array
end
27 changes: 27 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/arraystorage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
const ArrayStorageTensor{T,N,S,I} = Tensor{T,N,S,I} where {S<:ArrayStorage{T,N}}
const MatrixStorageTensor{T,S,I} = Tensor{T,2,S,I} where {S<:MatrixStorage{T}}
const MatrixOrArrayStorageTensor{T,S,I} =
Tensor{T,N,S,I} where {N,S<:MatrixOrArrayStorage{T}}

function Tensor(storage::MatrixOrArrayStorageTensor, inds::Tuple)
return Tensor(NeverAlias(), storage, inds)
end

function Tensor(as::AliasStyle, storage::MatrixOrArrayStorage, inds::Tuple)
return Tensor{eltype(storage),length(inds),typeof(storage),typeof(inds)}(
as, storage, inds
)
end

array(tensor::MatrixOrArrayStorageTensor) = storage(tensor)

# Linear algebra (matrix algebra)
function Base.adjoint(tens::MatrixStorageTensor)
return tensor(adjoint(storage(tens)), reverse(inds(tens)))
end

# Conversion from a tensor with TensorStorage storage
# to AbstractArray storage,
function to_arraytensor(x::DenseTensor)
return tensor(reshape(data(x), size(x)), inds(x))
end
74 changes: 74 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/combiner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
function contraction_output(
tensor1::MatrixOrArrayStorageTensor, tensor2::CombinerTensor, indsR
)
tensortypeR = contraction_output_type(typeof(tensor1), typeof(tensor2), indsR)
return NDTensors.similar(tensortypeR, indsR)
end

function contract!!(
output_tensor::ArrayStorageTensor,
output_tensor_labels,
combiner_tensor::CombinerTensor,
combiner_tensor_labels,
tensor::ArrayStorageTensor,
tensor_labels,
)
if ndims(combiner_tensor) ≤ 1
# Empty combiner, acts as multiplying by 1
output_tensor = permutedims!!(
output_tensor, tensor, getperm(output_tensor_labels, tensor_labels)
)
return output_tensor
end
if is_index_replacement(tensor, tensor_labels, combiner_tensor, combiner_tensor_labels)
ui = setdiff(combiner_tensor_labels, tensor_labels)[]
newind = inds(combiner_tensor)[findfirst(==(ui), combiner_tensor_labels)]
cpos1, cpos2 = intersect_positions(combiner_tensor_labels, tensor_labels)
output_tensor_storage = copy(storage(tensor))
output_tensor_inds = setindex(inds(tensor), newind, cpos2)
return NDTensors.tensor(output_tensor_storage, output_tensor_inds)
end
is_combining_contraction = is_combining(
tensor, tensor_labels, combiner_tensor, combiner_tensor_labels
)
if is_combining_contraction
Alabels, Blabels = tensor_labels, combiner_tensor_labels
final_labels = contract_labels(Blabels, Alabels)
final_labels_n = contract_labels(combiner_tensor_labels, tensor_labels)
output_tensor_inds = inds(output_tensor)
if final_labels != final_labels_n
perm = getperm(final_labels_n, final_labels)
output_tensor_inds = permute(inds(output_tensor), perm)
output_tensor_labels = permute(output_tensor_labels, perm)
end
cpos1, output_tensor_cpos = intersect_positions(
combiner_tensor_labels, output_tensor_labels
)
labels_comb = deleteat(combiner_tensor_labels, cpos1)
output_tensor_vl = [output_tensor_labels...]
for (ii, li) in enumerate(labels_comb)
insert!(output_tensor_vl, output_tensor_cpos + ii, li)
end
deleteat!(output_tensor_vl, output_tensor_cpos)
labels_perm = tuple(output_tensor_vl...)
perm = getperm(labels_perm, tensor_labels)
# TODO: Add a `reshape` for `ArrayStorageTensor`.
## tensorp = reshape(output_tensor, NDTensors.permute(inds(tensor), perm))
tensorp_inds = permute(inds(tensor), perm)
tensorp = NDTensors.tensor(reshape(storage(output_tensor), dims(tensorp_inds)), tensorp_inds)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
permutedims!(tensorp, tensor, perm)
# TODO: Add a `reshape` for `ArrayStorageTensor`.
## reshape(tensorp, output_tensor_inds)
return NDTensors.tensor(reshape(storage(tensorp), dims(output_tensor_inds)), output_tensor_inds)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
else # Uncombining
cpos1, cpos2 = intersect_positions(combiner_tensor_labels, tensor_labels)
output_tensor_storage = copy(storage(tensor))
indsC = deleteat(inds(combiner_tensor), cpos1)
output_tensor_inds = insertat(inds(tensor), indsC, cpos2)
# TODO: Add a `reshape` for `ArrayStorageTensor`.
return NDTensors.tensor(reshape(output_tensor_storage, dims(output_tensor_inds)), output_tensor_inds)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
end
return invalid_combiner_contraction_error(
tensor, tensor_labels, combiner_tensor, combiner_tensor_labels
)
end
19 changes: 19 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# TODO: Just call `contraction_output(storage(tensor1), storage(tensor2), indsR)`
function contraction_output(
tensor1::MatrixOrArrayStorageTensor, tensor2::MatrixOrArrayStorageTensor, indsR
)
tensortypeR = contraction_output_type(typeof(tensor1), typeof(tensor2), indsR)
return NDTensors.similar(tensortypeR, indsR)
end

function contract!(
tensorR::MatrixOrArrayStorageTensor,
labelsR,
tensor1::MatrixOrArrayStorageTensor,
labels1,
tensor2::MatrixOrArrayStorageTensor,
labels2,
)
contract!(storage(tensorR), labelsR, storage(tensor1), labels1, storage(tensor2), labels2)
return tensorR
end
8 changes: 8 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
function getindex(tensor::MatrixOrArrayStorageTensor, I::Integer...)
return storage(tensor)[I...]
end

function setindex!(tensor::MatrixOrArrayStorageTensor, v, I::Integer...)
storage(tensor)[I...] = v
return tensor
end
6 changes: 6 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/mul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
function LinearAlgebra.mul!(
C::MatrixStorageTensor, A::MatrixStorageTensor, B::MatrixStorageTensor
)
mul!(storage(C), storage(A), storage(B))
return C
end
9 changes: 9 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/permutedims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
function permutedims!(
output_tensor::MatrixOrArrayStorageTensor,
tensor::MatrixOrArrayStorageTensor,
perm,
f::Function,
)
permutedims!(storage(output_tensor), storage(tensor), perm, f)
return output_tensor
end
10 changes: 10 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function qr(A::ArrayStorageTensor)
Q, R = qr(storage(A))
Q = convert(typeof(R), Q)
i, j = inds(A)
q = size(A, 1) < size(A, 2) ? i : j
q = sim(q)
Qₜ = tensor(Q, (i, q))
Rₜ = tensor(R, (dag(q), j))
return Qₜ, Rₜ
end
141 changes: 141 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/tensor/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# TODO: Rewrite this function to be more modern:
# 1. List keyword arguments in function signature.
# 2. Remove `Dense` and `Diag`.
# 3. Output `Spectrum` as a keyword argument that gets overwritten.
# 4. Dispatch on `alg`.
# 5. Remove keyword argument deprecations.
# 6. Make this into two layers, one that handles indices and one that works with `Matrix`.
"""
svd(T::DenseTensor{<:Number,2}; kwargs...)

svd of an order-2 DenseTensor
"""
function svd(T::ArrayStorageTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT}
truncate = haskey(kwargs, :maxdim) || haskey(kwargs, :cutoff)

#
# Keyword argument deprecations
#
use_absolute_cutoff = false
if haskey(kwargs, :absoluteCutoff)
@warn "In svd, keyword argument absoluteCutoff is deprecated in favor of use_absolute_cutoff"
use_absolute_cutoff = get(kwargs, :absoluteCutoff, use_absolute_cutoff)
end

use_relative_cutoff = true
if haskey(kwargs, :doRelCutoff)
@warn "In svd, keyword argument doRelCutoff is deprecated in favor of use_relative_cutoff"
use_relative_cutoff = get(kwargs, :doRelCutoff, use_relative_cutoff)
end

if haskey(kwargs, :fastsvd) || haskey(kwargs, :fastSVD)
error(
"In svd, fastsvd/fastSVD keyword arguments are removed in favor of alg, see documentation for more details.",
)
end

maxdim::Int = get(kwargs, :maxdim, minimum(dims(T)))
mindim::Int = get(kwargs, :mindim, 1)
cutoff = get(kwargs, :cutoff, 0.0)
use_absolute_cutoff::Bool = get(kwargs, :use_absolute_cutoff, use_absolute_cutoff)
use_relative_cutoff::Bool = get(kwargs, :use_relative_cutoff, use_relative_cutoff)
alg::String = get(kwargs, :alg, "divide_and_conquer")

#@timeit_debug timer "dense svd" begin
if alg == "divide_and_conquer"
MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.DivideAndConquer())
if isnothing(MUSV)
# If "divide_and_conquer" fails, try "qr_iteration"
alg = "qr_iteration"
MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.QRIteration())
if isnothing(MUSV)
# If "qr_iteration" fails, try "recursive"
alg = "recursive"
MUSV = svd_recursive(matrix(T))
end
end
elseif alg == "qr_iteration"
MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.QRIteration())
if isnothing(MUSV)
# If "qr_iteration" fails, try "recursive"
alg = "recursive"
MUSV = svd_recursive(matrix(T))
end
elseif alg == "recursive"
MUSV = svd_recursive(matrix(T))
elseif alg == "QRAlgorithm" || alg == "JacobiAlgorithm"
MUSV = svd_catch_error(matrix(T); alg=alg)
else
error(
"svd algorithm $alg is not currently supported. Please see the documentation for currently supported algorithms.",
)
end
if isnothing(MUSV)
if any(isnan, T)
println("SVD failed, the matrix you were trying to SVD contains NaNs.")
else
println(lapack_svd_error_message(alg))
end
return nothing
end
MU, MS, MV = MUSV
conj!(MV)
#end # @timeit_debug

P = MS .^ 2
if truncate
P, truncerr, _ = truncate!!(
P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
)
else
truncerr = 0.0
end
spec = Spectrum(P, truncerr)
dS = length(P)
if dS < length(MS)
MU = MU[:, 1:dS]
# Fails on some GPU backends like Metal.
# resize!(MS, dS)
MS = MS[1:dS]
MV = MV[:, 1:dS]
end

# Make the new indices to go onto U and V
u = eltype(IndsT)(dS)
v = eltype(IndsT)(dS)
Uinds = IndsT((ind(T, 1), u))
Sinds = IndsT((u, v))
Vinds = IndsT((ind(T, 2), v))
U = tensor(MU, Uinds)
S = tensor(Diag(MS), Sinds)
V = tensor(MV, Vinds)
return U, S, V, spec
end

## function svd(
## tens::ArrayStorageTensor;
## alg,
## which_decomp,
## tags,
## mindim,
## cutoff,
## eigen_perturbation,
## normalize,
## maxdim,
## )
## error("Not implemented")
## F = svd(storage(tens))
## U, S, V = F.U, F.S, F.Vt
## i, j = inds(tens)
## # TODO: Make this more general with a `similar_ind` function,
## # so the dimension can be determined from the length of `S`.
## min_ij = dim(i) ≤ dim(j) ? i : j
## α = sim(min_ij) # similar_ind(i, space(S))
## β = sim(min_ij) # similar_ind(i, space(S))
## Utensor = tensor(U, (i, α))
## # TODO: Remove conversion to `Diagonal` to make more general, or make a generic `Diagonal` concept that works for `BlockSparseArray`.
## # Used for now to avoid introducing wrapper types.
## Stensor = tensor(Diagonal(S), (α, β))
## Vtensor = tensor(V, (β, j))
## return Utensor, Stensor, Vtensor, Spectrum(nothing, 0.0)
## end
Loading
Loading