From 2b3b9b5c5a007942501645197aa3093c1c98b6b5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Nov 2023 14:10:46 -0500 Subject: [PATCH] [NDTensors] Use PackageExtensionCompat --- NDTensors/Project.toml | 4 +-- .../ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl | 13 +++------- .../NDTensorsMetalExt/NDTensorsMetalExt.jl | 6 +---- .../NDTensorsOctavianExt.jl | 7 +----- .../NDTensorsTBLISExt/NDTensorsTBLISExt.jl | 7 +----- NDTensors/src/NDTensors.jl | 25 ++----------------- 6 files changed, 10 insertions(+), 52 deletions(-) diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index a7556b4680..ad67806ba2 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -14,8 +14,8 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -47,8 +47,8 @@ Functors = "0.2, 0.3, 0.4" HDF5 = "0.14, 0.15, 0.16, 0.17" InlineStrings = "1" LinearAlgebra = "1.6" +PackageExtensionCompat = "1" Random = "1.6" -Requires = "1.1" SimpleTraits = "0.9.4" SplitApplyCombine = "1.2.2" StaticArrays = "0.12, 1.0" diff --git a/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl b/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl index c81df9ed87..f9e8813152 100644 --- a/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl +++ b/NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl @@ -6,16 +6,9 @@ using NDTensors.Unwrap using Adapt using Functors using LinearAlgebra - -if isdefined(Base, :get_extension) - using CUDA - using CUDA.CUBLAS - using CUDA.CUSOLVER -else - using ..CUDA - using .CUBLAS - using .CUSOLVER -end +using CUDA +using CUDA.CUBLAS +using CUDA.CUSOLVER include("imports.jl") include("default_kwargs.jl") diff --git a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl index 36d72d7209..c90e34f94b 100644 --- a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl +++ b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl @@ -7,11 +7,7 @@ using NDTensors using NDTensors.SetParameters using NDTensors.Unwrap: qr_positive, ql_positive, ql -if isdefined(Base, :get_extension) - using Metal -else - using ..Metal -end +using Metal include("imports.jl") include("adapt.jl") diff --git a/NDTensors/ext/NDTensorsOctavianExt/NDTensorsOctavianExt.jl b/NDTensors/ext/NDTensorsOctavianExt/NDTensorsOctavianExt.jl index 0fe1dc3a46..c25dd84870 100644 --- a/NDTensors/ext/NDTensorsOctavianExt/NDTensorsOctavianExt.jl +++ b/NDTensors/ext/NDTensorsOctavianExt/NDTensorsOctavianExt.jl @@ -1,12 +1,7 @@ module NDTensorsOctavianExt using NDTensors - -if isdefined(Base, :get_extension) - using Octavian -else - using ..Octavian -end +using Octavian include("import.jl") include("octavian.jl") diff --git a/NDTensors/ext/NDTensorsTBLISExt/NDTensorsTBLISExt.jl b/NDTensors/ext/NDTensorsTBLISExt/NDTensorsTBLISExt.jl index 22bdcc7692..7edc3c78d9 100644 --- a/NDTensors/ext/NDTensorsTBLISExt/NDTensorsTBLISExt.jl +++ b/NDTensors/ext/NDTensorsTBLISExt/NDTensorsTBLISExt.jl @@ -2,12 +2,7 @@ module NDTensorsTBLISExt using NDTensors using LinearAlgebra -if isdefined(Base, :get_extension) - using TBLIS -else - using ..TBLIS -end -isdefined(Base, :get_extension) ? (using TBLIS) : (using ..TBLIS) +using TBLIS import NDTensors.contract! diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index f788052a73..aee947e2aa 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -295,10 +295,6 @@ end # Optional backends # -if !isdefined(Base, :get_extension) - using Requires -end - const _using_tblis = Ref(false) using_tblis() = _using_tblis[] @@ -315,26 +311,9 @@ end function backend_octavian end +using PackageExtensionCompat function __init__() - @static if !isdefined(Base, :get_extension) - @require TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9" begin - enable_tblis() - include("../ext/NDTensorsTBLISExt/NDTensorsTBLISExt.jl") - end - @require Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" begin - include("../ext/NDTensorsOctavianExt/NDTensorsOctavianExt.jl") - end - - @require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin - if CUDA.functional() - include("../ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl") - end - end - - @require Metal = "dde4c033-4e86-420c-a63e-0dd931031962" begin - include("../ext/NDTensorsMetalExt/NDTensorsMetalExt.jl") - end - end + @require_extensions end end # module NDTensors