Skip to content

Commit

Permalink
[NDTensors] DiagonalArray tensor operations (#1226)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Nov 1, 2023
1 parent 62f162f commit 70967cd
Show file tree
Hide file tree
Showing 23 changed files with 690 additions and 486 deletions.
28 changes: 12 additions & 16 deletions NDTensors/src/DiagonalArrays/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,32 @@ A Julia `DiagonalArray` type.
````julia
using NDTensors.DiagonalArrays:
DiagonalArray,
densearray,
diagview,
diaglength,
getdiagindex,
setdiagindex!,
setdiag!,
diagcopyto!

d = DiagonalArray([1., 2, 3], 3, 4, 5)
DiagIndex,
DiagIndices,
densearray

d = DiagonalArray([1.0, 2, 3], 3, 4, 5)
@show d[1, 1, 1] == 1
@show d[2, 2, 2] == 2
@show d[1, 2, 1] == 0

d[2, 2, 2] = 22
@show d[2, 2, 2] == 22

@show diaglength(d) == 3
@show length(d[DiagIndices()]) == 3
@show densearray(d) == d
@show getdiagindex(d, 2) == d[2, 2, 2]
@show d[DiagIndex(2)] == d[2, 2, 2]

setdiagindex!(d, 222, 2)
d[DiagIndex(2)] = 222
@show d[2, 2, 2] == 222

a = randn(3, 4, 5)
new_diag = randn(3)
setdiag!(a, new_diag)
diagcopyto!(d, a)
a[DiagIndices()] = new_diag
d[DiagIndices()] = a[DiagIndices()]

@show diagview(a) == new_diag
@show diagview(d) == new_diag
@show a[DiagIndices()] == new_diag
@show d[DiagIndices()] == new_diag
````

You can generate this README with:
Expand Down
24 changes: 8 additions & 16 deletions NDTensors/src/DiagonalArrays/examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,7 @@
#
# A Julia `DiagonalArray` type.

using NDTensors.DiagonalArrays:
DiagonalArray,
densearray,
diagview,
diaglength,
getdiagindex,
setdiagindex!,
setdiag!,
diagcopyto!
using NDTensors.DiagonalArrays: DiagonalArray, DiagIndex, DiagIndices, densearray

d = DiagonalArray([1.0, 2, 3], 3, 4, 5)
@show d[1, 1, 1] == 1
Expand All @@ -20,20 +12,20 @@ d = DiagonalArray([1.0, 2, 3], 3, 4, 5)
d[2, 2, 2] = 22
@show d[2, 2, 2] == 22

@show diaglength(d) == 3
@show length(d[DiagIndices()]) == 3
@show densearray(d) == d
@show getdiagindex(d, 2) == d[2, 2, 2]
@show d[DiagIndex(2)] == d[2, 2, 2]

setdiagindex!(d, 222, 2)
d[DiagIndex(2)] = 222
@show d[2, 2, 2] == 222

a = randn(3, 4, 5)
new_diag = randn(3)
setdiag!(a, new_diag)
diagcopyto!(d, a)
a[DiagIndices()] = new_diag
d[DiagIndices()] = a[DiagIndices()]

@show diagview(a) == new_diag
@show diagview(d) == new_diag
@show a[DiagIndices()] == new_diag
@show d[DiagIndices()] == new_diag

# You can generate this README with:
# ```julia
Expand Down
24 changes: 23 additions & 1 deletion NDTensors/src/DiagonalArrays/src/DiagonalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DiagonalArrays
using Compat # allequal
using LinearAlgebra

export DiagonalArray
export DiagonalArray, DiagonalMatrix, DiagonalVector, DiagIndex, DiagIndices, densearray

include("diagview.jl")

Expand All @@ -19,6 +19,9 @@ struct DiagonalArray{T,N,Diag<:AbstractVector{T},Zero} <: AbstractArray{T,N}
zero::Zero
end

const DiagonalVector{T,Diag,Zero} = DiagonalArray{T,1,Diag,Zero}
const DiagonalMatrix{T,Diag,Zero} = DiagonalArray{T,2,Diag,Zero}

function DiagonalArray{T,N}(
diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
) where {T,N}
Expand Down Expand Up @@ -53,6 +56,25 @@ function DiagonalArray(diag::AbstractVector{T}, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

default_size(diag::AbstractVector, n) = ntuple(Returns(length(diag)), n)

# Infer size from diagonal
function DiagonalArray{T,N}(diag::AbstractVector) where {T,N}
return DiagonalArray{T,N}(diag, default_size(diag, N))
end

function DiagonalArray{<:Any,N}(diag::AbstractVector{T}) where {T,N}
return DiagonalArray{T,N}(diag)
end

function DiagonalMatrix(diag::AbstractVector)
return DiagonalArray{<:Any,2}(diag)
end

function DiagonalVector(diag::AbstractVector)
return DiagonalArray{<:Any,1}(diag)
end

# undef
function DiagonalArray{T,N}(::UndefInitializer, d::Tuple{Vararg{Int,N}}) where {T,N}
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d)
Expand Down
46 changes: 37 additions & 9 deletions NDTensors/src/DiagonalArrays/src/diagview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ function setdiagindex!(a::AbstractArray, v, i::Integer)
return a
end

struct DiagIndex
I::Int
end

function Base.getindex(a::AbstractArray, i::DiagIndex)
return getdiagindex(a, i.I)
end

function Base.setindex!(a::AbstractArray, v, i::DiagIndex)
setdiagindex!(a, v, i.I)
return a
end

function setdiag!(a::AbstractArray, v)
copyto!(diagview(a), v)
return a
Expand All @@ -28,27 +41,42 @@ function diaglength(a::AbstractArray)
return minimum(size(a))
end

function diagstride(A::AbstractArray)
function diagstride(a::AbstractArray)
s = 1
p = 1
for i in 1:(ndims(A) - 1)
p *= size(A, i)
for i in 1:(ndims(a) - 1)
p *= size(a, i)
s += p
end
return s
end

function diagindices(A::AbstractArray)
diaglength = minimum(size(A))
maxdiag = LinearIndices(A)[CartesianIndex(ntuple(Returns(diaglength), ndims(A)))]
return 1:diagstride(A):maxdiag
function diagindices(a::AbstractArray)
diaglength = minimum(size(a))
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength), ndims(a)))]
return 1:diagstride(a):maxdiag
end

function diagview(A::AbstractArray)
return @view A[diagindices(A)]
function diagindices(a::AbstractArray{<:Any,0})
return Base.OneTo(1)
end

function diagview(a::AbstractArray)
return @view a[diagindices(a)]
end

function diagcopyto!(dest::AbstractArray, src::AbstractArray)
copyto!(diagview(dest), diagview(src))
return dest
end

struct DiagIndices end

function Base.getindex(a::AbstractArray, ::DiagIndices)
return diagview(a)
end

function Base.setindex!(a::AbstractArray, v, ::DiagIndices)
setdiag!(a, v)
return a
end
3 changes: 3 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ include("exports.jl")
#####################################
# General functionality
#
include("default_kwargs.jl")
include("algorithm.jl")
include("aliasstyle.jl")
include("abstractarray/set_types.jl")
Expand Down Expand Up @@ -151,6 +152,8 @@ include("arraystorage/arraystorage/tensor/eigen.jl")
include("arraystorage/arraystorage/tensor/svd.jl")

# DiagonalArray storage
include("arraystorage/diagonalarray/storage/contract.jl")

include("arraystorage/diagonalarray/tensor/contract.jl")

# BlockSparseArray storage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const ArrayStorage{T,N} = Union{
SubArray{T,N},
PermutedDimsArray{T,N},
StridedView{T,N},
DiagonalArray{T,N},
BlockSparseArray{T,N},
}

Expand All @@ -28,3 +29,8 @@ const MatrixOrArrayStorage{T} = Union{MatrixStorage{T},ArrayStorage{T}}
function to_arraystorage(x::DenseTensor)
return tensor(reshape(data(x), size(x)), inds(x))
end

# TODO: Delete once `Diag` is removed.
function to_arraystorage(x::DiagTensor)
return tensor(DiagonalArray(data(x), size(x)), inds(x))
end
5 changes: 3 additions & 2 deletions NDTensors/src/arraystorage/arraystorage/storage/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ function contract(
labels1,
array2::MatrixOrArrayStorage,
labels2,
labelsR=contract_labels(labels1, labels2),
labelsR=contract_labels(labels1, labels2);
kwargs...,
)
output_array = contraction_output(array1, labels1, array2, labels2, labelsR)
contract!(output_array, labelsR, array1, labels1, array2, labels2)
contract!(output_array, labelsR, array1, labels1, array2, labels2; kwargs...)
return output_array
end

Expand Down
17 changes: 5 additions & 12 deletions NDTensors/src/arraystorage/arraystorage/tensor/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
function eigen(
T::Hermitian{<:Any,<:ArrayStorageTensor};
maxdim=nothing,
mindim=1,
mindim=nothing,
cutoff=nothing,
use_absolute_cutoff=false,
use_relative_cutoff=true,
use_absolute_cutoff=nothing,
use_relative_cutoff=nothing,
# These are getting passed erroneously.
# TODO: Make sure they don't get passed down
# to here.
Expand All @@ -20,12 +20,6 @@ function eigen(
ortho=nothing,
svd_alg=nothing,
)
truncate = !isnothing(maxdim) || !isnothing(cutoff)
# TODO: Define `default_maxdim(T)`.
maxdim = isnothing(maxdim) ? minimum(dims(T)) : maxdim
# TODO: Define `default_cutoff(T)`.
cutoff = isnothing(cutoff) ? zero(eltype(T)) : cutoff

matrixT = matrix(T)
## TODO Here I am calling parent to ensure that the correct `any` function
## is envoked for non-cpu matrices
Expand All @@ -45,7 +39,7 @@ function eigen(
DM = DM[p]
VM = VM[:, p]

if truncate
if any(!isnothing, (maxdim, cutoff))
DM, truncerr, _ = truncate!!(
DM; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff
)
Expand All @@ -68,7 +62,6 @@ function eigen(
Vinds = indstype((dag(ind(T, 2)), dag(r)))
Dinds = indstype((l, dag(r)))
V = tensor(VM, Vinds)
# TODO: Replace with `DiagonalArray`.
D = tensor(Diag(DM), Dinds)
D = tensor(DiagonalMatrix(DM), Dinds)
return D, V, spec
end
3 changes: 2 additions & 1 deletion NDTensors/src/arraystorage/arraystorage/tensor/qr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
function qr(A::ArrayStorageTensor)
function qr(A::ArrayStorageTensor; positive=false)
positive && error("Not implemented")
Q, R = qr(storage(A))
Q = convert(typeof(R), Q)
i, j = inds(A)
Expand Down
Loading

0 comments on commit 70967cd

Please sign in to comment.