Skip to content

Commit

Permalink
polish diag rule (to fix a problem w.r.t. generic types) (#162)
Browse files Browse the repository at this point in the history
* polish diag rule (due to a test fail in GenericTensorNetworks)

* bump version
  • Loading branch information
GiggleLiu authored Jan 14, 2024
1 parent 8e28e09 commit 6968324
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OMEinsum"
uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
authors = ["Andreas Peter <[email protected]>"]
version = "0.8.0"
version = "0.8.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
16 changes: 14 additions & 2 deletions src/unaryrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,20 @@ function compactify!(y, x, ix, iy, sx, sy)
end

function _compactify!(y, x, indexer, sx, sy)
@inbounds for ci in CartesianIndices(y)
y[ci] = sy * y[ci] + sx * x[subindex(indexer, ci.I)]
if iszero(sy)
if isone(sx)
@inbounds for ci in CartesianIndices(y)
y[ci] = x[subindex(indexer, ci.I)]
end
else
@inbounds for ci in CartesianIndices(y)
y[ci] = sx * x[subindex(indexer, ci.I)]
end
end
else
@inbounds for ci in CartesianIndices(y)
y[ci] = sy * y[ci] + sx * x[subindex(indexer, ci.I)]
end
end
return y
end
Expand Down
5 changes: 4 additions & 1 deletion test/unaryrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ end
size_dict = Dict(1=>3,2=>3,3=>3)
x = randn(3,3,3,3,3)
y = randn(3,3,3)
@test unary_einsum!(Diag(), ix, iy, x, y, true, false) OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict)
@test unary_einsum!(Diag(), ix, iy, x, copy(y), true, false) OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict)
@test unary_einsum!(Diag(), ix, iy, x, copy(y), 2.0, 0.0) 2* OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict)
@test unary_einsum!(Diag(), ix, iy, x, copy(y), 1.0, 2.0) OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) + 2y
@test unary_einsum!(Diag(), ix, iy, x, copy(y), 2.0, 3.0) 2 * OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) + 3y
end

@testset "Repeat" begin
Expand Down

0 comments on commit 6968324

Please sign in to comment.