diff --git a/Project.toml b/Project.toml index 718c465..c07427c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OMEinsum" uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922" authors = ["Andreas Peter "] -version = "0.6.8" +version = "0.6.9" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/utils.jl b/src/utils.jl index a15089e..07f041e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -88,12 +88,12 @@ function tensorpermute(A::AbstractArray{T,N}, perm) where {T, N} N == 0 && return copy(A) # group `perm`s permshape = ntuple(i->size(A, @inbounds perm[i]), N) - newshape_slots = ones(Int, N) + newshape_slots = fill(-1, N) dk = 1 # the size of dimension-batch @inbounds begin permk = perm[1] newperm = [permk] - newshape_slots[permk] *= size(A, permk) + newshape_slots[permk] = size(A, permk) end @inbounds for i=2:N permi = perm[i] @@ -102,12 +102,12 @@ function tensorpermute(A::AbstractArray{T,N}, perm) where {T, N} dk += 1 else permk = permi - newshape_slots[permk] *= size(A, permi) + newshape_slots[permk] = size(A, permi) push!(newperm, permk) dk = 1 end end - newshape = filter(!isone, newshape_slots) + newshape = filter(!=(-1), newshape_slots) newperm = sortperm(sortperm(newperm)) A_ = reshape(A, newshape...) A__ = permutedims(A_, newperm) diff --git a/test/einsum.jl b/test/einsum.jl index 24bcd69..88a8a5b 100644 --- a/test/einsum.jl +++ b/test/einsum.jl @@ -246,3 +246,8 @@ end @test OMEinsum.duplicate(x, ix, iy, size_dict) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) @test OMEinsum.einsum(Duplicate(), (ix,), iy, (x,), size_dict) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) end + +@testset "issue 136" begin + @test EinCode(((1,2,3),(2,)),(1,3))(ones(2,2,1), ones(2)) == reshape([2,2.0], 2, 1) + @test EinCode(((1,2,3),(2,)),(1,3))(ones(2,2,0), ones(2)) == reshape(zeros(0), 2, 0) +end \ No newline at end of file