diff --git a/NDTensors/ext/NDTensorsCUDAExt/adapt.jl b/NDTensors/ext/NDTensorsCUDAExt/adapt.jl index 5d36f99f74..646bdaf67e 100644 --- a/NDTensors/ext/NDTensorsCUDAExt/adapt.jl +++ b/NDTensors/ext/NDTensorsCUDAExt/adapt.jl @@ -19,14 +19,16 @@ function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::AbstractArray) return isbits(xs) ? xs : CuArray{ElT,1,BufT}(xs) end -function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::NDTensors.UnallocatedZeros) +function Adapt.adapt_storage( + adaptor::NDTensorCuArrayAdaptor, xs::NDTensors.UnallocatedZeros +) arraytype_specified_1 = set_unspecified_parameters( CuArray, Position(1), get_parameter(xs, Position(1)) ) arraytype_specified_2 = set_unspecified_parameters( arraytype_specified_1, Position(2), get_parameter(xs, Position(2)) ) - + elt = get_parameter(arraytype_specified_2, Position(1)) N = get_parameter(arraytype_specified_2, Position(2)) return NDTensors.UnallocatedZeros{