From 44ec8e8a6e787a2437c8d61b927ba50bc05b94ce Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 3 Nov 2023 11:48:50 -0400 Subject: [PATCH] [NDTensorsMetalExt] Fix issues not importing some LinearAlgebra functions --- NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl | 3 ++- NDTensors/src/linearalgebra/svd.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl index 89ee4d42a3..2a1d3ee5e5 100644 --- a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl +++ b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl @@ -2,9 +2,10 @@ module NDTensorsMetalExt using Adapt using Functors -using LinearAlgebra: LinearAlgebra, Transpose, mul! +using LinearAlgebra: LinearAlgebra, Transpose, mul!, qr, eigen, svd using NDTensors using NDTensors.SetParameters +using NDTensors.Unwrap: qr_positive, ql_positive, ql if isdefined(Base, :get_extension) using Metal diff --git a/NDTensors/src/linearalgebra/svd.jl b/NDTensors/src/linearalgebra/svd.jl index 6fc4fe592b..acb0c3fd66 100644 --- a/NDTensors/src/linearalgebra/svd.jl +++ b/NDTensors/src/linearalgebra/svd.jl @@ -42,7 +42,7 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int= V = M' * U - V, R = qr_positive(V) + V, R = qr_positive(expose(V)) D[1:Nd] = diag(R)[1:Nd] (done, start) = svd_recursive_state(D, thresh)