From 3db3b192d178bd1de223d7c376ea6b3a997beb86 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 23 Oct 2023 14:29:41 -0400 Subject: [PATCH] Overload LinearAlgebra functions for Tensor decompositions --- NDTensors/src/tensor/linearalgebra.jl | 12 ++++++++++++ src/tensor_operations/matrix_decomposition.jl | 8 ++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/tensor/linearalgebra.jl b/NDTensors/src/tensor/linearalgebra.jl index 33722ce49a..484ce8811f 100644 --- a/NDTensors/src/tensor/linearalgebra.jl +++ b/NDTensors/src/tensor/linearalgebra.jl @@ -2,6 +2,18 @@ function LinearAlgebra.qr(T::Tensor; kwargs...) return qr(T; kwargs...) end +function LinearAlgebra.eigen(T::Tensor; kwargs...) + return eigen(T; kwargs...) +end + +function LinearAlgebra.eigen(T::Hermitian{<:Real,<:Tensor}; kwargs...) + return eigen(T; kwargs...) +end + +function LinearAlgebra.eigen(T::Hermitian{<:Complex{<:Real},<:Tensor}; kwargs...) + return eigen(T; kwargs...) +end + function LinearAlgebra.svd(T::Tensor; kwargs...) return svd(T; kwargs...) end diff --git a/src/tensor_operations/matrix_decomposition.jl b/src/tensor_operations/matrix_decomposition.jl index fede432847..21ca866a01 100644 --- a/src/tensor_operations/matrix_decomposition.jl +++ b/src/tensor_operations/matrix_decomposition.jl @@ -144,7 +144,7 @@ function svd(A::ITensor, Linds...; kwargs...) AC = permute(AC, cL, cR) end - USVT = NDTensors.svd(tensor(AC); kwargs...) + USVT = svd(tensor(AC); kwargs...) if isnothing(USVT) return nothing end @@ -337,7 +337,7 @@ function eigen(A::ITensor, Linds, Rinds; kwargs...) AT = ishermitian ? Hermitian(tensor(AC)) : tensor(AC) - DT, VT, spec = NDTensors.eigen(AT; kwargs...) + DT, VT, spec = eigen(AT; kwargs...) D, VC = itensor(DT), itensor(VT) V = VC * dag(CR) @@ -433,7 +433,7 @@ lq(A::ITensor, Linds...; kwargs...) = lq(A, Linds, uniqueinds(A, Linds); kwargs. # Handle default tags and dispatch to generic qx/xq functions. # function qr(A::ITensor, Linds::Indices, Rinds::Indices; tags=ts"Link,qr", kwargs...) - return qx(NDTensors.qr, A, Linds, Rinds; tags, kwargs...) + return qx(qr, A, Linds, Rinds; tags, kwargs...) end function ql(A::ITensor, Linds::Indices, Rinds::Indices; tags=ts"Link,ql", kwargs...) return qx(ql, A, Linds, Rinds; tags, kwargs...) @@ -442,7 +442,7 @@ function rq(A::ITensor, Linds::Indices, Rinds::Indices; tags=ts"Link,rq", kwargs return xq(ql, A, Linds, Rinds; tags, kwargs...) end function lq(A::ITensor, Linds::Indices, Rinds::Indices; tags=ts"Link,lq", kwargs...) - return xq(NDTensors.qr, A, Linds, Rinds; tags, kwargs...) + return xq(qr, A, Linds, Rinds; tags, kwargs...) end # # Generic function implementing both qr and ql decomposition. The X tensor = R or L.