diff --git a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl index 89ee4d42a3..2a1d3ee5e5 100644 --- a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl +++ b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl @@ -2,9 +2,10 @@ module NDTensorsMetalExt using Adapt using Functors -using LinearAlgebra: LinearAlgebra, Transpose, mul! +using LinearAlgebra: LinearAlgebra, Transpose, mul!, qr, eigen, svd using NDTensors using NDTensors.SetParameters +using NDTensors.Unwrap: qr_positive, ql_positive, ql if isdefined(Base, :get_extension) using Metal diff --git a/NDTensors/src/linearalgebra/svd.jl b/NDTensors/src/linearalgebra/svd.jl index 6fc4fe592b..acb0c3fd66 100644 --- a/NDTensors/src/linearalgebra/svd.jl +++ b/NDTensors/src/linearalgebra/svd.jl @@ -42,7 +42,7 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int= V = M' * U - V, R = qr_positive(V) + V, R = qr_positive(expose(V)) D[1:Nd] = diag(R)[1:Nd] (done, start) = svd_recursive_state(D, thresh)