From d8cb8bdcded85a5e6c93c71f9b133b6eda3d7d55 Mon Sep 17 00:00:00 2001 From: kmp5VT Date: Fri, 14 Jun 2024 13:47:06 -0400 Subject: [PATCH] array necessary here because .= fails for CUDA gpu --- NDTensors/src/tensor/tensor.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NDTensors/src/tensor/tensor.jl b/NDTensors/src/tensor/tensor.jl index a255aa82ef..4ff25c46dd 100644 --- a/NDTensors/src/tensor/tensor.jl +++ b/NDTensors/src/tensor/tensor.jl @@ -369,7 +369,7 @@ function diag(tensor::Tensor) tensordiag = NDTensors.similar( dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),) ) - tensordiag .= diagview(tensor) + array(tensordiag) .= diagview(tensor) return tensordiag end