From f5c7701e06df386739fd7f543f8fe7af2d08b1a3 Mon Sep 17 00:00:00 2001 From: kmp5VT Date: Thu, 2 Nov 2023 14:49:04 -0400 Subject: [PATCH] Add type based default_svd_alg --- NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl | 1 + NDTensors/ext/NDTensorsCUDAExt/default_kwargs.jl | 1 + NDTensors/src/arraystorage/arraystorage/tensor/svd.jl | 2 +- NDTensors/src/default_kwargs.jl | 2 +- NDTensors/src/linearalgebra/linearalgebra.jl | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) create mode 100644 NDTensors/ext/NDTensorsCUDAExt/default_kwargs.jl diff --git a/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl b/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl index 1841bcc9a4..c81df9ed87 100644 --- a/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl +++ b/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl @@ -18,6 +18,7 @@ else end include("imports.jl") +include("default_kwargs.jl") include("set_types.jl") include("iscu.jl") include("adapt.jl") diff --git a/NDTensors/ext/NDTensorsCUDAExt/default_kwargs.jl b/NDTensors/ext/NDTensorsCUDAExt/default_kwargs.jl new file mode 100644 index 0000000000..67f943ef78 --- /dev/null +++ b/NDTensors/ext/NDTensorsCUDAExt/default_kwargs.jl @@ -0,0 +1 @@ +NDTensors.default_svd_alg(a, ::Type{<:CuArray}) = "qr_algorithm" \ No newline at end of file diff --git a/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl b/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl index 69f446ac53..89e5c4f8d7 100644 --- a/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl +++ b/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl @@ -51,7 +51,7 @@ function tsvd( # Only used by BlockSparse svd min_blockdim=nothing, ) - alg = replace_nothing(alg, default_svd_alg(a)) + alg = replace_nothing(alg, default_svd_alg(a, unwrap_type(a))) USV = svd(Algorithm(alg), a) if isnothing(USV) if any(isnan, a) diff --git a/NDTensors/src/default_kwargs.jl b/NDTensors/src/default_kwargs.jl index 73d8756d84..8209d820fb 100644 --- a/NDTensors/src/default_kwargs.jl +++ b/NDTensors/src/default_kwargs.jl @@ -4,6 +4,6 @@ replace_nothing(value, replacement) = value default_maxdim(a) = minimum(size(a)) default_mindim(a) = true default_cutoff(a) = zero(eltype(a)) -default_svd_alg(a) = "divide_and_conquer" +default_svd_alg(a, ::Type{<:AbstractArray}) = "divide_and_conquer" default_use_absolute_cutoff(a) = false default_use_relative_cutoff(a) = true diff --git a/NDTensors/src/linearalgebra/linearalgebra.jl b/NDTensors/src/linearalgebra/linearalgebra.jl index 6dd793eba8..f15c9fe2a9 100644 --- a/NDTensors/src/linearalgebra/linearalgebra.jl +++ b/NDTensors/src/linearalgebra/linearalgebra.jl @@ -105,7 +105,7 @@ function svd( # Only used by BlockSparse svd min_blockdim=nothing, ) where {ElT,IndsT} - alg = replace_nothing(alg, default_svd_alg(T)) + alg = replace_nothing(alg, default_svd_alg(T, unwrap_type(T))) if alg == "divide_and_conquer" MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.DivideAndConquer()) if isnothing(MUSV)