From e1939b673895400c5718cf83337dce86323b9bb8 Mon Sep 17 00:00:00 2001 From: Karl Pierce Date: Fri, 19 Jan 2024 19:16:01 -0500 Subject: [PATCH] =?UTF-8?q?[NDTensors]=20Replace=20`axpby`=20with=20broadc?= =?UTF-8?q?ast=20when=20=CE=B2=20=3D=3D=200=20(#1309)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- NDTensors/src/dense/tensoralgebra/contract.jl | 6 ++--- NDTensors/test/test_dense.jl | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/dense/tensoralgebra/contract.jl b/NDTensors/src/dense/tensoralgebra/contract.jl index 5171507ce6..10b13d14ca 100644 --- a/NDTensors/src/dense/tensoralgebra/contract.jl +++ b/NDTensors/src/dense/tensoralgebra/contract.jl @@ -68,8 +68,7 @@ function _contract_scalar_noperm!( if iszero(α) fill!(Rᵈ, 0) else - # Rᵈ .= α .* T₂ᵈ - LinearAlgebra.axpby!(α, Tᵈ, β, Rᵈ) + Rᵈ .= α .* Tᵈ end elseif isone(β) if iszero(α) @@ -81,8 +80,7 @@ function _contract_scalar_noperm!( end else if iszero(α) - # Rᵈ .= β .* Rᵈ - LinearAlgebra.scal!(length(Rᵈ), β, Rᵈ, 1) + Rᵈ .= β .* Rᵈ else # Rᵈ .= α .* Tᵈ .+ β .* Rᵈ LinearAlgebra.axpby!(α, Tᵈ, β, Rᵈ) diff --git a/NDTensors/test/test_dense.jl b/NDTensors/test/test_dense.jl index 2180877e4a..7cd2be6ffb 100644 --- a/NDTensors/test/test_dense.jl +++ b/NDTensors/test/test_dense.jl @@ -232,6 +232,32 @@ using .NDTensorsTestUtils: devices_list end # Only CPU backend testing + @testset "Contract with exotic types" begin + # BigFloat is not supported on GPU + ## randn(BigFloat, ...) is not defined in Julia 1.6 + a = BigFloat.(randn(Float64, 2, 3)) + t = Tensor(a, (1, 2, 3)) + m = Tensor(a, (2, 3)) + v = Tensor([one(BigFloat)], (1,)) + + @test m ≈ contract(t, (-1, 2, 3), v, (-1,)) + tp = similar(t) + NDTensors.contract!(tp, (1, 2, 3), t, (1, 2, 3), v, (1,), false, false) + @test iszero(tp) + + fill!(tp, one(BigFloat)) + NDTensors.contract!(tp, (1, 2, 3), t, (1, 2, 3), v, (1,), false, true) + for i in tp + @test i == one(BigFloat) + end + + rand_factor = BigFloat(randn(Float64)) + NDTensors.contract!(tp, (1, 2, 3), t, (1, 2, 3), v, (1,), false, rand_factor) + for i in tp + @test i == rand_factor + end + end + @testset "change backends" begin a, b, c = [randn(5, 5) for i in 1:3] backend_auto()