Skip to content

Commit

Permalink
fix issue 136 (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu authored Jan 10, 2022
1 parent 6fb492e commit cba4c3a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OMEinsum"
uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
authors = ["Andreas Peter <[email protected]>"]
version = "0.6.8"
version = "0.6.9"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ function tensorpermute(A::AbstractArray{T,N}, perm) where {T, N}
N == 0 && return copy(A)
# group `perm`s
permshape = ntuple(i->size(A, @inbounds perm[i]), N)
newshape_slots = ones(Int, N)
newshape_slots = fill(-1, N)
dk = 1 # the size of dimension-batch
@inbounds begin
permk = perm[1]
newperm = [permk]
newshape_slots[permk] *= size(A, permk)
newshape_slots[permk] = size(A, permk)
end
@inbounds for i=2:N
permi = perm[i]
Expand All @@ -102,12 +102,12 @@ function tensorpermute(A::AbstractArray{T,N}, perm) where {T, N}
dk += 1
else
permk = permi
newshape_slots[permk] *= size(A, permi)
newshape_slots[permk] = size(A, permi)
push!(newperm, permk)
dk = 1
end
end
newshape = filter(!isone, newshape_slots)
newshape = filter(!=(-1), newshape_slots)
newperm = sortperm(sortperm(newperm))
A_ = reshape(A, newshape...)
A__ = permutedims(A_, newperm)
Expand Down
5 changes: 5 additions & 0 deletions test/einsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,8 @@ end
@test OMEinsum.duplicate(x, ix, iy, size_dict) OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict)
@test OMEinsum.einsum(Duplicate(), (ix,), iy, (x,), size_dict) OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict)
end

@testset "issue 136" begin
@test EinCode(((1,2,3),(2,)),(1,3))(ones(2,2,1), ones(2)) == reshape([2,2.0], 2, 1)
@test EinCode(((1,2,3),(2,)),(1,3))(ones(2,2,0), ones(2)) == reshape(zeros(0), 2, 0)
end

0 comments on commit cba4c3a

Please sign in to comment.