-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors] Get more
Array
storage functionality working (#1222)
- Loading branch information
Showing
26 changed files
with
622 additions
and
143 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
NDTensors/src/arraystorage/arraystorage/storage/arraystorage.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Used for dispatch to distinguish from Tensors wrapping TensorStorage. | ||
# Remove once TensorStorage is removed. | ||
const ArrayStorage{T,N} = Union{ | ||
Array{T,N}, | ||
ReshapedArray{T,N}, | ||
SubArray{T,N}, | ||
PermutedDimsArray{T,N}, | ||
StridedView{T,N}, | ||
BlockSparseArray{T,N}, | ||
} | ||
|
||
const MatrixStorage{T} = Union{ | ||
ArrayStorage{T,2}, | ||
Transpose{T}, | ||
Adjoint{T}, | ||
Symmetric{T}, | ||
Hermitian{T}, | ||
UpperTriangular{T}, | ||
LowerTriangular{T}, | ||
UnitUpperTriangular{T}, | ||
UnitLowerTriangular{T}, | ||
Diagonal{T}, | ||
} | ||
|
||
const MatrixOrArrayStorage{T} = Union{MatrixStorage{T},ArrayStorage{T}} | ||
|
||
# TODO: Delete once `Dense` is removed. | ||
function to_arraystorage(x::DenseTensor) | ||
return tensor(reshape(data(x), size(x)), inds(x)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
conj(as::AliasStyle, A::AbstractArray) = conj(A) | ||
conj(as::AllowAlias, A::Array{<:Real}) = A |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 8 additions & 0 deletions
8
NDTensors/src/arraystorage/arraystorage/storage/permutedims.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function permutedims!( | ||
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function | ||
) | ||
output_array = permutedims!!( | ||
leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f | ||
) | ||
return output_array | ||
end |
22 changes: 22 additions & 0 deletions
22
NDTensors/src/arraystorage/arraystorage/tensor/arraystorage.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
const ArrayStorageTensor{T,N,S,I} = Tensor{T,N,S,I} where {S<:ArrayStorage{T,N}} | ||
const MatrixStorageTensor{T,S,I} = Tensor{T,2,S,I} where {S<:MatrixStorage{T}} | ||
const MatrixOrArrayStorageTensor{T,S,I} = | ||
Tensor{T,N,S,I} where {N,S<:MatrixOrArrayStorage{T}} | ||
|
||
function Tensor(storage::MatrixOrArrayStorageTensor, inds::Tuple) | ||
return Tensor(NeverAlias(), storage, inds) | ||
end | ||
|
||
function Tensor(as::AliasStyle, storage::MatrixOrArrayStorage, inds::Tuple) | ||
return Tensor{eltype(storage),length(inds),typeof(storage),typeof(inds)}( | ||
as, storage, inds | ||
) | ||
end | ||
|
||
array(tensor::MatrixOrArrayStorageTensor) = storage(tensor) | ||
|
||
# Linear algebra (matrix algebra) | ||
# TODO: Remove `Base.`? Is it imported? | ||
function Base.adjoint(tens::MatrixStorageTensor) | ||
return tensor(adjoint(storage(tens)), reverse(inds(tens))) | ||
end |
31 changes: 31 additions & 0 deletions
31
NDTensors/src/arraystorage/arraystorage/tensor/contract.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# TODO: Just call `contraction_output(storage(tensor1), storage(tensor2), indsR)` | ||
function contraction_output( | ||
tensor1::MatrixOrArrayStorageTensor, tensor2::MatrixOrArrayStorageTensor, indsR | ||
) | ||
tensortypeR = contraction_output_type(typeof(tensor1), typeof(tensor2), indsR) | ||
return NDTensors.similar(tensortypeR, indsR) | ||
end | ||
|
||
# TODO: Define `default_α` and `default_β`. | ||
function contract!( | ||
tensor_dest::MatrixOrArrayStorageTensor, | ||
labels_dest, | ||
tensor1::MatrixOrArrayStorageTensor, | ||
labels1, | ||
tensor2::MatrixOrArrayStorageTensor, | ||
labels2, | ||
α=one(eltype(tensor_dest)), | ||
β=zero(eltype(tensor_dest)); | ||
) | ||
contract!( | ||
storage(tensor_dest), | ||
labels_dest, | ||
storage(tensor1), | ||
labels1, | ||
storage(tensor2), | ||
labels2, | ||
α, | ||
β, | ||
) | ||
return tensor_dest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# TODO: Rewrite this function to be more modern: | ||
# 1. List keyword arguments in function signature. | ||
# 2. Output `Spectrum` as a keyword argument that gets overwritten. | ||
# 3. Make this into two layers, one that handles indices and one that works with `AbstractMatrix`. | ||
function eigen( | ||
T::Hermitian{<:Any,<:ArrayStorageTensor}; | ||
maxdim=nothing, | ||
mindim=1, | ||
cutoff=nothing, | ||
use_absolute_cutoff=false, | ||
use_relative_cutoff=true, | ||
# These are getting passed erroneously. | ||
# TODO: Make sure they don't get passed down | ||
# to here. | ||
which_decomp=nothing, | ||
tags=nothing, | ||
eigen_perturbation=nothing, | ||
normalize=nothing, | ||
ishermitian=nothing, | ||
ortho=nothing, | ||
svd_alg=nothing, | ||
) | ||
truncate = !isnothing(maxdim) || !isnothing(cutoff) | ||
# TODO: Define `default_maxdim(T)`. | ||
maxdim = isnothing(maxdim) ? minimum(dims(T)) : maxdim | ||
# TODO: Define `default_cutoff(T)`. | ||
cutoff = isnothing(cutoff) ? zero(eltype(T)) : cutoff | ||
|
||
matrixT = matrix(T) | ||
## TODO Here I am calling parent to ensure that the correct `any` function | ||
## is envoked for non-cpu matrices | ||
if any(!isfinite, parent(matrixT)) | ||
throw( | ||
ArgumentError( | ||
"Trying to perform the eigendecomposition of a matrix containing NaNs or Infs" | ||
), | ||
) | ||
end | ||
|
||
DM, VM = eigen(matrixT) | ||
|
||
# Sort by largest to smallest eigenvalues | ||
# TODO: Replace `cpu` with `leaf_parenttype` dispatch. | ||
p = sortperm(cpu(DM); rev=true, by=abs) | ||
DM = DM[p] | ||
VM = VM[:, p] | ||
|
||
if truncate | ||
DM, truncerr, _ = truncate!!( | ||
DM; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff | ||
) | ||
dD = length(DM) | ||
if dD < size(VM, 2) | ||
VM = VM[:, 1:dD] | ||
end | ||
else | ||
dD = length(DM) | ||
truncerr = 0.0 | ||
end | ||
spec = Spectrum(DM, truncerr) | ||
|
||
# Make the new indices to go onto V | ||
# TODO: Put in a separate function, such as | ||
# `rewrap_inds` or something like that. | ||
indstype = typeof(inds(T)) | ||
l = eltype(indstype)(dD) | ||
r = eltype(indstype)(dD) | ||
Vinds = indstype((dag(ind(T, 2)), dag(r))) | ||
Dinds = indstype((l, dag(r))) | ||
V = tensor(VM, Vinds) | ||
# TODO: Replace with `DiagonalArray`. | ||
D = tensor(Diag(DM), Dinds) | ||
return D, V, spec | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function getindex(tensor::MatrixOrArrayStorageTensor, I::Integer...) | ||
return storage(tensor)[I...] | ||
end | ||
|
||
function setindex!(tensor::MatrixOrArrayStorageTensor, v, I::Integer...) | ||
storage(tensor)[I...] = v | ||
return tensor | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
function LinearAlgebra.mul!( | ||
C::MatrixStorageTensor, A::MatrixStorageTensor, B::MatrixStorageTensor | ||
) | ||
mul!(storage(C), storage(A), storage(B)) | ||
return C | ||
end |
9 changes: 9 additions & 0 deletions
9
NDTensors/src/arraystorage/arraystorage/tensor/permutedims.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
function permutedims!( | ||
output_tensor::MatrixOrArrayStorageTensor, | ||
tensor::MatrixOrArrayStorageTensor, | ||
perm, | ||
f::Function, | ||
) | ||
permutedims!(storage(output_tensor), storage(tensor), perm, f) | ||
return output_tensor | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
function qr(A::ArrayStorageTensor) | ||
Q, R = qr(storage(A)) | ||
Q = convert(typeof(R), Q) | ||
i, j = inds(A) | ||
q = size(A, 1) < size(A, 2) ? i : j | ||
q = sim(q) | ||
Qₜ = tensor(Q, (i, q)) | ||
Rₜ = tensor(R, (dag(q), j)) | ||
return Qₜ, Rₜ | ||
end |
Oops, something went wrong.