From ada189b2b7c1acf5420a808e556da842ba946967 Mon Sep 17 00:00:00 2001 From: kmp5VT Date: Wed, 20 Sep 2023 16:46:42 -0400 Subject: [PATCH] fix the adapt function for CUDA --- NDTensors/ext/NDTensorsCUDAExt/adapt.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/NDTensors/ext/NDTensorsCUDAExt/adapt.jl b/NDTensors/ext/NDTensorsCUDAExt/adapt.jl index f35ed98ac7..5d36f99f74 100644 --- a/NDTensors/ext/NDTensorsCUDAExt/adapt.jl +++ b/NDTensors/ext/NDTensorsCUDAExt/adapt.jl @@ -19,18 +19,16 @@ function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::AbstractArray) return isbits(xs) ? xs : CuArray{ElT,1,BufT}(xs) end -function Adapt.adapt_storage(arraytype::Type{<:CuArray}, xs::NDTensors.UnallocatedZeros) +function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::NDTensors.UnallocatedZeros) arraytype_specified_1 = set_unspecified_parameters( - arraytype, Position(1), get_parameter(xs, Position(1)) + CuArray, Position(1), get_parameter(xs, Position(1)) ) arraytype_specified_2 = set_unspecified_parameters( arraytype_specified_1, Position(2), get_parameter(xs, Position(2)) ) - arraytype_specified_3 = set_unspecified_parameters( - arraytype_specified_2, Position(3), get_parameter(xs, Position(3)) - ) - elt = get_parameter(arraytype_specified_3, Position(1)) - N = get_parameter(arraytype_specified_3, Position(2)) + + elt = get_parameter(arraytype_specified_2, Position(1)) + N = get_parameter(arraytype_specified_2, Position(2)) return NDTensors.UnallocatedZeros{ elt,N,CUDA.CuArray{elt,N,default_parameter(CuArray, Position(3))} }(