Skip to content

Commit

Permalink
Add type based default_svd_alg
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Nov 2, 2023
1 parent 63498c4 commit f5c7701
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ else
end

include("imports.jl")
include("default_kwargs.jl")
include("set_types.jl")
include("iscu.jl")
include("adapt.jl")
Expand Down
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/default_kwargs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NDTensors.default_svd_alg(a, ::Type{<:CuArray}) = "qr_algorithm"
2 changes: 1 addition & 1 deletion NDTensors/src/arraystorage/arraystorage/tensor/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function tsvd(
# Only used by BlockSparse svd
min_blockdim=nothing,
)
alg = replace_nothing(alg, default_svd_alg(a))
alg = replace_nothing(alg, default_svd_alg(a, unwrap_type(a)))
USV = svd(Algorithm(alg), a)
if isnothing(USV)
if any(isnan, a)
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/default_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ replace_nothing(value, replacement) = value
default_maxdim(a) = minimum(size(a))
default_mindim(a) = true
default_cutoff(a) = zero(eltype(a))
default_svd_alg(a) = "divide_and_conquer"
default_svd_alg(a, ::Type{<:AbstractArray}) = "divide_and_conquer"
default_use_absolute_cutoff(a) = false
default_use_relative_cutoff(a) = true
2 changes: 1 addition & 1 deletion NDTensors/src/linearalgebra/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function svd(
# Only used by BlockSparse svd
min_blockdim=nothing,
) where {ElT,IndsT}
alg = replace_nothing(alg, default_svd_alg(T))
alg = replace_nothing(alg, default_svd_alg(T, unwrap_type(T)))
if alg == "divide_and_conquer"
MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.DivideAndConquer())
if isnothing(MUSV)
Expand Down

0 comments on commit f5c7701

Please sign in to comment.