Skip to content

Commit

Permalink
Add logic to fix CUDA qr with rectangular matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Nov 1, 2023
1 parent 66f5d39 commit 48a8c15
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
function NDTensors.svd_catch_error(A::CuMatrix; alg="JacobiAlgorithm")
function NDTensors.svd_catch_error(A::CuMatrix; alg::String="JacobiAlgorithm")
if alg == "JacobiAlgorithm"
alg = CUDA.CUSOLVER.JacobiAlgorithm()
else
alg = CUDA.CUSOLVER.QRAlgorithm()
end
return NDTensors.svd_catch_error(A, alg)
end

function NDTensors.svd_catch_error(A::CuMatrix, ::CUDA.CUSOLVER.JacobiAlgorithm)
USV = try
svd(expose(A); alg=alg)
svd(A; alg=CUDA.CUSOLVER.JacobiAlgorithm())
catch
return nothing
end
return USV
end

function NDTensors.svd_catch_error(A::CuMatrix, ::CUDA.CUSOLVER.QRAlgorithm)
s = size(A)
if s[1] < s[2]
At = copy(Adjoint(A))

Check warning on line 23 in NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:23:- NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:24:- USV = try NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:23:+ NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:24:+ USV = try
USV = try
svd(At; alg=CUDA.CUSOLVER.QRAlgorithm())
catch
return nothing
end
MV ,MS, MU = USV;

Check warning on line 29 in NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:29:- MV ,MS, MU = USV; NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:29:+ MV, MS, MU = USV
USV = SVD(copy(MU), MS, Adjoint(MV))
else
USV = try

Check warning on line 32 in NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:32:- USV = try NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:32:+ USV = try
svd(A; alg=CUDA.CUSOLVER.QRAlgorithm())
catch
return nothing
end
end
return USV
end

Check warning on line 39 in NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:39:-end NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl:39:+end

0 comments on commit 48a8c15

Please sign in to comment.