From 98a77240cdf675b4ca977abf310fd17dfeb2e349 Mon Sep 17 00:00:00 2001 From: Karl Pierce Date: Sat, 12 Oct 2024 21:29:28 -0400 Subject: [PATCH] [ITensors] Fix broken broadcast operation on GPU (#1532) --- NDTensors/test/test_dense.jl | 14 ++++++++++++++ src/broadcast.jl | 11 ++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/NDTensors/test/test_dense.jl b/NDTensors/test/test_dense.jl index 94c52f4132..c2b327811a 100644 --- a/NDTensors/test/test_dense.jl +++ b/NDTensors/test/test_dense.jl @@ -83,6 +83,20 @@ NDTensors.dim(i::MyInd) = i.dim @test A[2, 2] == Aview[1, 1] end + ## Testing A .= α .* B .+ β .* A + C = copy(A) + @allowscalar fill!(B, zero(elt)) + β = elt(2) + α = elt(1) + permutedims!!(A, B, (1, 2), (a, b) -> +(*(β, a), *(α, b))) + @allowscalar 2 .* C == A + randn!(B) + C = copy(A) + A = permutedims!!(A, B, (1, 2), (a, b) -> +(*(β, a), *(α, b))) + @allowscalar for i in 1:3, j in 1:4 + @test A[i, j] == α * B[i, j] + β * C[i, j] + end + ## add elt around 2.0 to preserve the eltype of A. @test data(A * elt(2.0)) == data(elt(2.0) * A) diff --git a/src/broadcast.jl b/src/broadcast.jl index 3e91d64714..4f0848d580 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -395,6 +395,13 @@ end # C .= β .* C .+ α .* A .* B # +struct axpby{Alpha,Beta} <: Function + alpha::Alpha + beta::Beta +end + +(f::axpby)(y, x) = x * f.alpha + y * f.beta + ## TODO this code doesn't actually get called function Base.copyto!( T::ITensor, @@ -414,7 +421,9 @@ function Base.copyto!( A, C = C, A end if !isnothing(A) && !isnothing(C) && !isnothing(α) && !isnothing(β) - map!((r, t) -> β * r + α * t, T, T, A) + # The following fails to compile on some GPU backends. + # map!((r, t) -> β * r + α * t, T, T, A) + map!(axpby(α, β), T, T, A) else bc_bc_α = find_type(Broadcasted, bc_α.args) if isnothing(α)