diff --git a/README.md b/README.md
index c0f748c..af076a4 100644
--- a/README.md
+++ b/README.md
@@ -224,6 +224,20 @@ julia> gradient(x->optcode(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum
 ```
 This tells us that even if we allow duplicates on one vertex, there are no 3-colourings for the peterson graph.
 
+## Comparison with other packages
+Similar packages include:
+- [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl) and [TensorKit.jl](https://github.com/Jutho/TensorKit.jl)
+- [ITensors.jl](https://github.com/ITensor/ITensors.jl)
+
+Compared with the packages above, `OMEinsum` is optimized for large-scale tensor network (i.e. einsum, or sum-product network) contraction. Its main advantages are:
+- `OMEinsum` has better support for very high-dimensional tensor networks and for optimizing their contraction order.
+- `OMEinsum` allows an index to appear multiple times.
+- `OMEinsum` has well-tested support for generic element types.
+
+However, `OMEinsum` also has some disadvantages:
+- `OMEinsum` does not support tensors with good quantum numbers (block-sparse symmetric tensors).
+- `OMEinsum` is less optimized for small-scale problems.
+
 ## Contribute
 
 Suggestions and Comments in the _Issues_ are welcome.
diff --git a/benchmark/Project.toml b/benchmark/Project.toml
new file mode 100644
index 0000000..a027bf8
--- /dev/null
+++ b/benchmark/Project.toml
@@ -0,0 +1,3 @@
+[deps]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
diff --git a/benchmark/large.jl b/benchmark/large.jl
new file mode 100644
index 0000000..1dd82b8
--- /dev/null
+++ b/benchmark/large.jl
@@ -0,0 +1,10 @@
+using OMEinsum, BenchmarkTools
+
+function benchmark_tensorpermute()
+    # tensorpermute
+    A = randn(fill(2, 28)...);
+    C = zero(A);
+    perm = [18 22 11 21 15 9 10 19 24 14 5 1 17 20 26 25 28 27 7 6 3 13 12 16 8 23 2 4] |> vec
+    @btime OMEinsum.tensorpermute!($C, $A, $perm, true, false) evals=3
+    nothing
+end
\ No newline at end of file
diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl
index 54de708..8141914 100644
--- a/ext/CUDAExt.jl
+++ b/ext/CUDAExt.jl
@@ -1,6 +1,6 @@
 module CUDAExt
 
-import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm, asscalar
+import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm!, asscalar, @flatten_addmul!
 using OMEinsum: EinArray, Diag, Repeat, Duplicate, DefaultRule, EinCode, DynamicEinCode, StaticEinCode, NestedEinsum, SimpleBinaryRule, match_rule, loop_einsum, getiy, getixs, _unique, einarray, align_eltypes, siblings, isleaf, tensorindex, _safe_set, rootcode
 import OMEinsum
 using LinearAlgebra
@@ -15,54 +15,64 @@ _unwrap(x::CuArray) = x
 asarray(x, arr::CuArray) = CuArray(fill(x, ()))
 asarray(x::AbstractArray, y::CuArray) = x
-asscalar(x::DenseCuArray) = Array(x)[]
+asscalar(x::CUDAArrayTypes) = Array(x)[]
+
+# to avoid returning a ReshapedArray
+OMEinsum.safe_reshape(x::CuArray, sz) = reshape(x, (sz...,))
+OMEinsum.safe_reshape(x::Adjoint{T, <:CuArray{T}} where T, sz) = reshape(CuArray(x), (sz...,))
+OMEinsum.safe_reshape(x::Transpose{T, <:CuArray{T}} where T, sz) = reshape(CuArray(x), (sz...,))
 Base.Array(x::Base.ReshapedArray{T,0,<:CuArray}) where T = Array(x.parent)
-function get_output_array(xs::NTuple{N, DenseCuArray{<:Any,M} where M}, size; has_repeated_indices=true) where N
+function get_output_array(xs::NTuple{N, CUDAArrayTypes{<:Any,M} where M}, size; fillzero=true) where N
     CUDA.zeros(promote_type(map(eltype,xs)...), size...)
end -function get_output_array(xs::NTuple{N, DenseCuArray{T,M} where M}, size; has_repeated_indices=true) where {T,N} +function get_output_array(xs::NTuple{N, CUDAArrayTypes{T,M} where M}, size; fillzero=true) where {T,N} CUDA.zeros(T, size...) end CUDA.cudaconvert(A::EinArray{T}) where T = EinArray{T}(cudaconvert.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS) CUDA.cu(A::EinArray{T}) where T = EinArray{T}(cu.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS) -for TP in [:Diag, :Repeat, :Duplicate, :DefaultRule] - @eval function einsum(::$TP, ixs, iy, xs::Tuple{<:DenseCuArray}, size_dict::Dict{LT}) where LT - @debug "cueinsum fallback to loop_einsum" rule ixs => iy size.(xs) - loop_einsum(EinCode(ixs, iy), xs, size_dict) +for TP in [:Diag, :Repeat, :Duplicate] + @eval function OMEinsum.unary_einsum!(::$TP, ix, iy, x::CUDAArrayTypes, y::CUDAArrayTypes, sx, sy) + @debug "cueinsum fallback to loop_einsum" rule ix => iy size(x) + size_dict = OMEinsum.get_size_dict((ix, iy), (x, y)) + loop_einsum!((ix,), iy, (x,), y, sx, sy, size_dict) end end -function einsum(::SimpleBinaryRule{('j',), ('j',), ()}, xs::NTuple{2, DenseCuArray}) - dropdims(reshape(xs[1],1,:) * xs[2]; dims=1) -end - -function loop_einsum!(code::EinCode, - xs::NTuple{N, DenseCuArray{<:Any,M} where M}, - y::DenseCuArray{T,L}, size_dict::Dict{LT}) where {N,L,T, LT} - iy = (getiy(code)...,) - ixs = (Tuple.(getixs(code))...,) +function loop_einsum!(ixs0, iy0, + xs::NTuple{N, CUDAArrayTypes{<:Any,M} where M}, + y::CUDAArrayTypes{T,L}, sx, sy, size_dict::Dict{LT}) where {N,L,T, LT} + iy = (iy0...,) + ixs = (Tuple.(ixs0)...,) iy_ = _unique(LT,iy) NO = length(iy_) - A = einarray(Val(ixs), Val(iy), xs, size_dict) + A = einarray(Val(ixs), Val((iy_...,)), xs, size_dict) if NO == length(iy) - y = reshape(y, fill(1, ndims(A)-NO)...,size(y)...) - raw = Base.mapreducedim!(x->x, +, y, A) + raw_ = similar(y, (fill(1, ndims(A)-NO)...,size(y)...,)) + fill!(raw_, zero(T)) + Base.mapreducedim!(x->x, +, raw_, A) if ndims(A)-NO > 0 # fix 1.7 compatibility - raw = dropdims(raw, dims=(1:ndims(A)-NO...,)) + raw = dropdims(raw_, dims=(1:ndims(A)-NO...,)) + else + raw = raw_ end - return raw + return @flatten_addmul! sy * y + sx * raw else - y_ = CUDA.zeros(T, size(A)[end-NO+1:end]...) - y_ = reshape(y_, fill(1, ndims(A)-NO)...,size(y_)...) + if iszero(sy) + fill!(y, zero(T)) + else + lmul!(sy, y) + end + y_ = similar(y, (fill(1, ndims(A)-NO)...,[size_dict[l] for l in iy_]...)) + fill!(y_, zero(T)) raw = Base.mapreducedim!(x->x, +, y_, A) if ndims(A)-NO > 0 # fix 1.7 compatibility raw = dropdims(raw, dims=(1:ndims(A)-NO...,)) end - return expanddims!(Val{((iy_...,),)}(), Val{iy}(), raw, y) + return expanddims!(Val{((iy_...,),)}(), Val{iy}(), raw, y, sx) end end @@ -72,7 +82,7 @@ end Expr(:tuple, ids...) 
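A note on the convention used by the mutating kernels above and throughout this patch: the trailing `sx, sy` arguments follow the BLAS-like contract `y ← sx * ⟨contraction⟩ + sy * y`. Below is a minimal CPU sketch of that contract, using only names that appear in this diff; it assumes `loop_einsum!` and `get_size_dict` keep the argument order shown in `src/loop_einsum.jl`.

```julia
using OMEinsum

a, b = randn(3, 4), randn(4, 5)
y = randn(3, 5); y0 = copy(y)
ixs, iy = (('i', 'j'), ('j', 'k')), ('i', 'k')
size_dict = OMEinsum.get_size_dict(ixs, (a, b))

# in-place, scaled contraction: y .= 2.0 .* (a * b) .+ 0.5 .* y0
OMEinsum.loop_einsum!(ixs, iy, (a, b), y, 2.0, 0.5, size_dict)
y ≈ 2.0 .* (a * b) .+ 0.5 .* y0   # expected to hold under this convention
```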
end -function expanddims!(::Val{ixs}, ::Val{iy}, x, y) where {ixs,iy} +function expanddims!(::Val{ixs}, ::Val{iy}, x, y, sx) where {ixs,iy} nthreads = 256 nblocks = cld(prod(size(x)), nthreads) CIS = CartesianIndices(x) @@ -80,23 +90,33 @@ function expanddims!(::Val{ixs}, ::Val{iy}, x, y) where {ixs,iy} i = (blockIdx().x-1) * blockDim().x + threadIdx().x i > length(x) && return nothing @inbounds yi = expandind(Val{ixs}(), Val{iy}(), CIS[i].I) - @inbounds y[CartesianIndex(yi)] = x[i] + @inbounds y[CartesianIndex(yi)] += sx * x[i] nothing end @cuda(blocks=nblocks, threads=nthreads, kernel(y, x)) return y end -function _batched_gemm(C1::Char, C2::Char, A::DenseCuArray{T1,3}, B::DenseCuArray{T2,3}) where {T1<:CuBlasFloat, T2<:CuBlasFloat} - CUDA.CUBLAS.gemm_strided_batched(C1, C2, align_eltypes(A,B)...) -end - -function einsum(::SimpleBinaryRule{(),(), ()}, xs::NTuple{2, DenseCuArray}) - asarray(Array(xs[1])[] * Array(xs[2])[], xs[1]) +function _batched_gemm!(C1::Char, C2::Char, alpha, A::CUDAArrayTypes{T1,3}, B::CUDAArrayTypes{T2,3}, beta, C::CUDAArrayTypes{T3,3}) where {T1<:CuBlasFloat, T2<:CuBlasFloat, T3<:CuBlasFloat} + CUDA.CUBLAS.gemm_strided_batched!(C1, C2, alpha, T1 == T3 ? A : T3.(A), T2 == T3 ? B : T3.(B), beta, C) end Base.ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}}) = 0 +function einsum!(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), @nospecialize(y::CUDAArrayTypes), sx, sy, size_dict::Dict; active_free=false) + # do not use map because the static overhead is too large + # do not use `setindex!` because we need to make the AD work + mxs = Vector{AbstractArray}(undef, length(siblings(neinsum))) + for (i, arg) in enumerate(siblings(neinsum)) + mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum(arg, xs, similar(y, ([size_dict[l] for l in getiy(rootcode(arg))]...,)), true, false, size_dict; active_free=active_free)) + end + res = einsum!(rootcode(neinsum), (mxs...,), y, sx, sy, size_dict) + active_free && for mx in mxs # free CuArray aggressively. 
+ CUDA.unsafe_free!(mx) + end + return res +end + function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict; active_free=false) # do not use map because the static overhead is too large # do not use `setindex!` because we need to make the AD work @@ -111,17 +131,6 @@ function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes return res end -# to dispatch Adjoint correctly -@generated function einsum(code::StaticEinCode{LT,ixs, iy}, xs::NTuple{N,CUDAArrayTypes} where N, size_dict::Dict{LT}) where {LT, ixs, iy} - rule = match_rule(ixs, iy) - :(einsum($rule, $ixs, $iy, _unwrap.(xs), size_dict)) -end - -function einsum(code::DynamicEinCode, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict) - rule = match_rule(getixs(code), getiy(code)) - einsum(rule, getixs(code), getiy(code), _unwrap.(xs), size_dict) -end - @info("OMEinsum loaded the CUDA module successfully") end \ No newline at end of file diff --git a/src/OMEinsum.jl b/src/OMEinsum.jl index 24bf6e8..31825ac 100644 --- a/src/OMEinsum.jl +++ b/src/OMEinsum.jl @@ -7,7 +7,7 @@ using AbstractTrees import LinearAlgebra: BlasFloat export @ein_str, @ein, ein -export einsum, dynamic_einsum +export einsum!, einsum, dynamic_einsum export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum export getiyv, getixsv, uniquelabels, labeltype export flop @@ -33,6 +33,8 @@ include("utils.jl") include("unaryrules.jl") include("binaryrules.jl") +include("matchrule.jl") +include("einsum.jl") include("interfaces.jl") include("einsequence.jl") diff --git a/src/binaryrules.jl b/src/binaryrules.jl index ec82d24..e4f5756 100644 --- a/src/binaryrules.jl +++ b/src/binaryrules.jl @@ -7,96 +7,6 @@ function SimpleBinaryRule(code::EinCode) end SimpleBinaryRule(ix1, ix2, iy) = SimpleBinaryRule{ix1, ix2, iy}() -@inline function _add_patch(::SimpleBinaryRule{ix1,ix2,iy}) where {ix1,ix2,iy} - SimpleBinaryRule{(ix1...,'l'), (ix2...,'l'), (iy...,'l')}() -end -@inline _add_patch(::DefaultRule) = DefaultRule() - -function match_rule_binary(ix1, ix2, iy) - Nx1, Nx2, Ny = length(ix1), length(ix2), length(iy) - if !_isunique(ix1) || !_isunique(ix2) || !_isunique(iy) - DefaultRule() - elseif (Nx1 + Nx2 + Ny) % 2 == 0 # no batch - _match_simple2(ix1,ix2,iy,Nx1,Nx2,Ny) - elseif Nx1>0 && Nx2>0 && Ny>0 && ix1[Nx1]==ix2[Nx2]==iy[Ny] - rule = _match_simple2(ix1,ix2,iy,Nx1-1,Nx2-1,Ny-1) - _add_patch(rule) - else - DefaultRule() - end -end -@inline function _isunique(ix) - if length(ix) <= 1 - return true - elseif length(ix) == 2 - return @inbounds ix[1] != ix[2] - elseif length(ix) == 3 - @inbounds a, b, c = ix - return a != c && a != c && a != b - else # to default rules - return false - end -end - -function _match_simple2(ix1, ix2, iy, Nx1, Nx2, Ny) - if Nx1==0 - if (Ny==Nx2==0) - return SimpleBinaryRule((), (), ()) - elseif (Ny==Nx2==1 && ix2[1] == iy[1]) - return SimpleBinaryRule((), ('k',), ('k',)) - end - elseif Nx1==1 - if (Nx2==0 && Ny==1 && iy[1]==ix1[1]) - return SimpleBinaryRule(('i',), (), ('i',)) - elseif (Nx2==1 && Ny==0 && ix1[1]==ix2[1]) - return SimpleBinaryRule(('j',), ('j',), ()) - elseif Nx2==1 && Ny==2 - if (iy[1]==ix1[1] && iy[2]==ix2[1]) - return SimpleBinaryRule(('i',), ('k',), ('i','k')) - elseif iy[1]==ix2[1] && iy[2]==ix1[1] - return SimpleBinaryRule(('i',), ('k',), ('k','i')) - end - elseif Nx2==2 && Ny==1 - if ix2[1]==ix1[1] && ix2[2]==iy[1] - return SimpleBinaryRule(('j',), 
('j','k'), ('k',)) - elseif ix2[1]==iy[1] && ix2[2]==ix1[1] - return SimpleBinaryRule(('j',), ('k','j'), ('k',)) - end - end - elseif Nx1==2 - if Nx2==1 && Ny==1 - if ix1[1]==ix2[1] && ix1[2]==iy[1] - return SimpleBinaryRule(('j','i'), ('j',), ('i',)) - elseif ix1[1]==iy[1] && ix1[2]==ix2[1] - return SimpleBinaryRule(('i','j'), ('j',), ('i',)) - end - elseif (Nx2==2 && Ny==2) - if ix1[1]==ix2[1] && ix1[2]==iy[1] && ix2[2]==iy[2] - return SimpleBinaryRule(('j','i'), ('j','k'), ('i','k')) - elseif ix1[1]==ix2[2] && ix1[2]==iy[1] && ix2[1]==iy[2] - return SimpleBinaryRule(('j','i'), ('k','j'), ('i','k')) - elseif ix1[1]==ix2[1] && ix1[2]==iy[2] && ix2[2]==iy[1] - return SimpleBinaryRule(('j','i'), ('j','k'), ('k','i')) - elseif ix1[1]==ix2[2] && ix1[2]==iy[2] && ix2[1]==iy[1] - return SimpleBinaryRule(('j','i'), ('k','j'), ('k','i')) - elseif ix1[2]==ix2[1] && ix1[1]==iy[1] && ix2[2]==iy[2] - return SimpleBinaryRule(('i','j'), ('j','k'), ('i','k')) - elseif ix1[2]==ix2[2] && ix1[1]==iy[1] && ix2[1]==iy[2] - return SimpleBinaryRule(('i','j'), ('k','j'), ('i','k')) - elseif ix1[2]==ix2[1] && ix1[1]==iy[2] && ix2[2]==iy[1] - return SimpleBinaryRule(('i','j'), ('j','k'), ('k','i')) - elseif ix1[2]==ix2[2] && ix1[1]==iy[2] && ix2[1]==iy[1] - return SimpleBinaryRule(('i','j'), ('k','j'), ('k','i')) - end - end - end - return DefaultRule() -end - -function einsum(rule::SimpleBinaryRule, ixs, iy, xs::NTuple{2, Any}, size_dict) - @debug rule size.(xs) - einsum(rule, xs) -end # Code is a binary representation of `(O1,I,O2,B)`. # Because the time complexity of `GEMM` and `BatchedGEMM` are higher than space complexity, we allow `permutedims`. # We reduce the contraction to these basic forms through `permutedims` and reshape, @@ -105,105 +15,107 @@ end # ,-> : 000 # S = 1 # T = 1 -function einsum(::SimpleBinaryRule{(),(), ()}, xs::NTuple{2, Any}) - asarray(asscalar(xs[1]) * asscalar(xs[2]), xs[1]) +function binary_einsum!(::SimpleBinaryRule{(),(), ()}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * x1 * x2 end # i,->i : 100 # S = N # T = N -function einsum(::SimpleBinaryRule{('i',),(), ('i',)}, xs::NTuple{2, Any}) - xs[1] .* Ref(asscalar(xs[2])) +function binary_einsum!(::SimpleBinaryRule{('i',),(), ('i',)}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * x1 * Ref(asscalar(x2)) end # j,j-> : 010 # S = N # T = N -function einsum(::SimpleBinaryRule{('j',), ('j',), ()}, xs::NTuple{2, Any}) - asarray(transpose(xs[1]) * xs[2], xs[1]) +function binary_einsum!(::SimpleBinaryRule{('j',), ('j',), ()}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * (reshape(x1, 1, length(x1)) * x2) end # ,k->k : 001 # S = N # T = N -@inline function einsum(::SimpleBinaryRule{(), ('k',), ('k',)}, xs::NTuple{2, Any}) - einsum(SimpleBinaryRule{('i',),(),('i',)}(), (xs[2], xs[1])) +@inline function binary_einsum!(::SimpleBinaryRule{(), ('k',), ('k',)}, x1, x2, y, sx, sy) + @addmul! 
sy * y + sx * Ref(asscalar(x1)) * x2 end # j,jk->k : 011 # S = N^2 # T = N^2 -function einsum(::SimpleBinaryRule{('j',), ('j','k'), ('k',)}, xs::NTuple{2, Any}) - vec(transpose(xs[1]) * xs[2]) +function binary_einsum!(::SimpleBinaryRule{('j',), ('j','k'), ('k',)}, x1, x2, y, sx, sy) + mul!(y, transpose(x2), x1, sx, sy) end -function einsum(::SimpleBinaryRule{('j',), ('k','j'), ('k',)}, xs::NTuple{2, Any}) - xs[2] * xs[1] +function binary_einsum!(::SimpleBinaryRule{('j',), ('k','j'), ('k',)}, x1, x2, y, sx, sy) + mul!(y, x2, x1, sx, sy) end # ij,j->i : 110 # S = N^2 # T = N^2 -@inline function einsum(::SimpleBinaryRule{('i','j'),('j',), ('i',)}, xs::NTuple{2, Any}) - einsum(SimpleBinaryRule{('j',),('k','j'), ('k',)}(), (xs[2], xs[1])) +@inline function binary_einsum!(::SimpleBinaryRule{('i','j'),('j',), ('i',)}, x1, x2, y, sx, sy) + mul!(y, x1, x2, sx, sy) end -@inline function einsum(::SimpleBinaryRule{('j','i'),('j',), ('i',)}, xs::NTuple{2, Any}) - einsum(SimpleBinaryRule{('j',),('j','k'), ('k',)}(), (xs[2], xs[1])) +@inline function binary_einsum!(::SimpleBinaryRule{('j','i'),('j',), ('i',)}, x1, x2, y, sx, sy) + mul!(y, transpose(x1), x2, sx, sy) end # i,k->ik : 101 # S = N^2 # T = N^2 -function einsum(::SimpleBinaryRule{('i',), ('k',), ('i','k')}, xs::NTuple{2, Any}) - xs[1] * transpose(xs[2]) +function binary_einsum!(::SimpleBinaryRule{('i',), ('k',), ('i','k')}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * x1 * reshape(x2, 1, length(x2)) end -@inline function einsum(::SimpleBinaryRule{('i',), ('k',),('k','i')}, xs::NTuple{2, Any}) - einsum(SimpleBinaryRule{('i',),('k',),('i','k')}(), (xs[2], xs[1])) +@inline function binary_einsum!(::SimpleBinaryRule{('i',), ('k',),('k','i')}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * reshape(x1, 1, length(x1)) * x2 end # 000 -function einsum(::SimpleBinaryRule{('l',),('l',), ('l',)}, xs::NTuple{2, Any}) - xs[1] .* xs[2] +function binary_einsum!(::SimpleBinaryRule{('l',),('l',), ('l',)}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * x1 * x2 end # 100 -function einsum(::SimpleBinaryRule{('i','l'),('l',), ('i','l')}, xs::NTuple{2, Any}) - xs[1] .* transpose(xs[2]) +function binary_einsum!(::SimpleBinaryRule{('i','l'),('l',), ('i','l')}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * x1 * reshape(x2, 1, length(x2)) end # 001 -@inline function einsum(::SimpleBinaryRule{('l',), ('k','l'), ('k','l')}, xs::NTuple{2, Any}) - einsum(SimpleBinaryRule{('i','l'),('l',),('i','l')}(), (xs[2], xs[1])) +@inline function binary_einsum!(::SimpleBinaryRule{('l',), ('k','l'), ('k','l')}, x1, x2, y, sx, sy) + @addmul! sy * y + sx * reshape(x1, 1, length(x1)) * x2 end # 010 -function einsum(::SimpleBinaryRule{('j','l'), ('j','l'), ('l',)}, xs::NTuple{2, Any}) - a, b = xs - dropdims(mapreduce(*, +, a, b; dims=1); dims=1) +function binary_einsum!(::SimpleBinaryRule{('j','l'), ('j','l'), ('l',)}, x1, x2, y, sx, sy) + @addmul! 
sy * y + sx * dropdims(mapreduce(*, +, x1, x2; dims=1); dims=1) end # 101 -function einsum(::SimpleBinaryRule{('i','l'), ('k','l'), ('i','k','l')}, xs::NTuple{2, Any}) - a, b = xs - _batched_gemm('N', 'N', reshape(a, size(a, 1), 1, size(a, 2)), reshape(b, 1, size(b, 1), size(b, 2))) +function binary_einsum!(::SimpleBinaryRule{('i','l'), ('k','l'), ('i','k','l')}, x1, x2, y::AbstractArray, sx, sy) + _batched_gemm!('N', 'N', sx, reshape(x1, size(x1, 1), 1, size(x1, 2)), reshape(x2, 1, size(x2, 1), size(x2, 2)), sy, y) end -@inline function einsum(::SimpleBinaryRule{('i','l'), ('k','l'), ('k','i','l')}, xs::NTuple{2, Any}) - einsum(SimpleBinaryRule{('i','l'),('k','l'), ('i','k','l')}(), (xs[2], xs[1])) +@inline function binary_einsum!(::SimpleBinaryRule{('i','l'), ('k','l'), ('k','i','l')}, x1, x2, y::AbstractArray, sx, sy) + _batched_gemm!('N', 'N', sx, reshape(x2, size(x2, 1), 1, size(x2, 2)), reshape(x1, 1, size(x1, 1), size(x1, 2)), sy, y) end # 011 -function einsum(::SimpleBinaryRule{('j','l'), ('j','k','l'), ('k','l')}, xs::NTuple{2, Any}) - reshape(_batched_gemm('N', 'N', reshape(xs[1], 1, size(xs[1],1), size(xs[1],2)), xs[2]), size(xs[2],2), size(xs[2],3)) +function binary_einsum!(::SimpleBinaryRule{('j','l'), ('j','k','l'), ('k','l')}, x1, x2, y::AbstractArray, sx, sy) + _batched_gemm!('N', 'N', sx, reshape(x1, 1, size(x1,1), size(x1,2)), x2, sy, reshape(y, 1, size(y,1), size(y,2))) + y end -function einsum(::SimpleBinaryRule{('j','l'), ('k','j','l'), ('k','l')}, xs::NTuple{2, Any}) - reshape(_batched_gemm('N', 'T', reshape(xs[1], 1, size(xs[1],1), size(xs[1],2)), xs[2]), size(xs[2],1), size(xs[2],3)) +function binary_einsum!(::SimpleBinaryRule{('j','l'), ('k','j','l'), ('k','l')}, x1, x2, y::AbstractArray, sx, sy) + _batched_gemm!('N', 'T', sx, reshape(x1, 1, size(x1,1), size(x1,2)), x2, sy, reshape(y, 1, size(y,1), size(y,2))) + y end # 110 -function einsum(::SimpleBinaryRule{('i','j','l'), ('j','l'), ('i','l')}, xs::NTuple{2, Any}) - reshape(_batched_gemm('N', 'N', xs[1], reshape(xs[2], size(xs[2],1), 1, size(xs[2],2))), size(xs[1],1), size(xs[1],3)) +function binary_einsum!(::SimpleBinaryRule{('i','j','l'), ('j','l'), ('i','l')}, x1, x2, y::AbstractArray, sx, sy) + _batched_gemm!('N', 'N', sx, x1, reshape(x2, size(x2,1), 1, size(x2,2)), sy, reshape(y, size(y,1), 1, size(y,2))) + y end -function einsum(::SimpleBinaryRule{('j','i','l'), ('j','l'), ('i','l')}, xs::NTuple{2, Any}) - reshape(_batched_gemm('T', 'N', xs[1], reshape(xs[2], size(xs[2],1), 1, size(xs[2],2))), size(xs[1],2), size(xs[1],3)) +function binary_einsum!(::SimpleBinaryRule{('j','i','l'), ('j','l'), ('i','l')}, x1, x2, y::AbstractArray, sx, sy) + _batched_gemm!('T', 'N', sx, x1, reshape(x2, size(x2,1), 1, size(x2,2)), sy, reshape(y, size(y,1), 1, size(y,2))) + y end # ij,jk->ik : 111 @@ -212,142 +124,19 @@ end for (i1, X1) in enumerate([('i', 'j'), ('j', 'i')]) for (i2, X2) in enumerate([('j', 'k'), ('k', 'j')]) for (i3, X3) in enumerate([('i', 'k'), ('k', 'i')]) - A1 = i1==i3 ? :(xs[1]) : :(transpose(xs[1])) - A2 = i2==i3 ? :(xs[2]) : :(transpose(xs[2])) - @eval function einsum(::SimpleBinaryRule{$X1,$X2, $X3}, xs::NTuple{2, Any}) - $(i3==1 ? :($A1*$A2) : :($A2*$A1)) + A1 = i1==i3 ? :(x1) : :(transpose(x1)) + A2 = i2==i3 ? :(x2) : :(transpose(x2)) + @eval function binary_einsum!(::SimpleBinaryRule{$X1,$X2, $X3}, x1, x2, y::AbstractArray{T}, sx, sy) where T + $(i3==1 ? :(mul!(y, $A1, $A2, sx, sy)) : :(mul!(y, $A2, $A1, sx, sy))) end X1B = (X1...,'l') X2B = (X2...,'l') X3B = (X3...,'l') C1 = i1==i3 ? 
'N' : 'T' C2 = i2==i3 ? 'N' : 'T' - @eval function einsum(::SimpleBinaryRule{$X1B,$X2B,$X3B}, xs::NTuple{2, Any}) - $(i3==1 ? :(_batched_gemm($C1, $C2, xs[1], xs[2])) : :(_batched_gemm($C2, $C1, xs[2], xs[1]))) + @eval function binary_einsum!(::SimpleBinaryRule{$X1B,$X2B,$X3B}, x1, x2, y::AbstractArray{T}, sx, sy) where T + $(i3==1 ? :(_batched_gemm!($C1, $C2, sx, x1, x2, sy, y)) : :(_batched_gemm!($C2, $C1, sx, x2, x1, sy, y))) end end end -end - -# there are too many combination in the binary case, so nospecialize -function einsum(::DefaultRule, ixs, iy, @nospecialize(xs::NTuple{2, Any}), size_dict::Dict{LT}) where LT - @debug "DefaultRule binary" ixs => iy size.(xs) - ix1, ix2 = ixs - x1, x2 = xs - c1, c2, cy, s1, s2, i1, i2, iyb = analyze_binary(_collect(LT,ix1), _collect(LT,ix2), _collect(LT,iy), size_dict) - rule = SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}() - x1 = simplify_unary(_collect(LT,ix1), c1, x1, size_dict) - x2 = simplify_unary(_collect(LT,ix2), c2, x2, size_dict) - x1_ = reshape(x1, s1...) - x2_ = reshape(x2, s2...) - @debug rule size.((x1_, x2_)) - y_ = reshape(einsum(rule, (x1_, x2_)), [size_dict[x] for x in cy]...) - return expand_unary(cy, _collect(LT,iy), y_, size_dict) -end - -function simplify_unary(ix::Vector{T}, iy::Vector{T}, x, size_dict::Dict{T}) where T - if ix == iy - return x - elseif length(ix) == length(iy) # permutation - return einsum(Permutedims(), (ix,), iy, (x,), size_dict) - else - # diag - ix_ = unique(ix) - x_ = length(ix_) != length(ix) ? einsum(Diag(), (ix,), ix_, (x,), size_dict) : x - # sum - if length(ix_) != length(iy) - return einsum(Sum(), (ix_,), iy, (x_,), size_dict) - elseif ix_ != iy - return einsum(Permutedims(), (ix_,), iy, (x_,), size_dict) - else - return x_ - end - end -end - -function expand_unary(ix::Vector{T}, iy::Vector{T}, x::AbstractArray, size_dict::Dict{T}) where T - iy_b = unique(iy) - iy_a = filter(i->i ∈ ix, iy_b) - y_a = if ix != iy_a - einsum(Permutedims(), (ix,), iy_a, (x,), size_dict) - else - x - end - # repeat - y_b = length(iy_a) != length(iy_b) ? einsum(Repeat(), (iy_a,), iy_b, (y_a,), size_dict) : y_a - # duplicate - length(iy_b) != length(iy) ? einsum(Duplicate(), (iy_b,), iy, (y_b,), size_dict) : y_b -end - -""" -Get the expected labels. 
-""" -function analyze_binary(ix1::Vector{T}, ix2::Vector{T}, iy::Vector{T}, size_dict::Dict{T,Int}) where T - ix_inner, ix1_outer, ix2_outer, batch = _analyze_binary_input(ix1, ix2, iy) - c1 = vcat(ix1_outer, ix_inner, batch) - c2 = vcat(ix_inner, ix2_outer, batch) - cy = vcat(ix1_outer, ix2_outer, batch) - si = prod(map(x->size_dict[x], ix1_outer)) - sj = prod(map(x->size_dict[x], ix_inner)) - sk = prod(map(x->size_dict[x], ix2_outer)) - sl = prod(map(x->size_dict[x], batch)) - has_i = !isempty(ix1_outer) - has_j = !isempty(ix_inner) - has_k = !isempty(ix2_outer) - has_l = !isempty(batch) - i1 = Char[] - i2 = Char[] - iyb = Char[] - s1 = Int[] - s2 = Int[] - if has_i - push!(i1, 'i') - push!(iyb, 'i') - push!(s1, si) - end - if has_j - push!(i1, 'j') - push!(i2, 'j') - push!(s1, sj) - push!(s2, sj) - end - if has_k - push!(i2, 'k') - push!(iyb, 'k') - push!(s2, sk) - end - if has_l - push!(i1, 'l') - push!(i2, 'l') - push!(iyb, 'l') - push!(s1, sl) - push!(s2, sl) - end - return c1, c2, cy, s1, s2, i1, i2, iyb -end - -function _analyze_binary_input(ix1::Vector{T}, ix2::Vector{T}, iy::Vector{T}) where T - ix1_batch = T[] - ix1_inner = T[] - ix1_outer = T[] - for l1 in ix1 - if l1 ∈ ix2 - if l1 ∈ iy # batch - l1 ∉ ix1_batch && push!(ix1_batch, l1) - else # inner - l1 ∉ ix1_inner && push!(ix1_inner, l1) - end - elseif l1 ∈ iy # outer - l1 ∉ ix1_outer && push!(ix1_outer, l1) - else - # dangling - end - end - ix2_outer = T[] # outer dimension of x2 - for l2 in ix2 - if l2 ∉ ix1 && l2 ∈ iy && l2 ∉ ix2_outer - push!(ix2_outer, l2) - end - end - ix1_inner, ix1_outer, ix2_outer, ix1_batch -end +end \ No newline at end of file diff --git a/src/einsequence.jl b/src/einsequence.jl index 3eae41a..ad87c64 100644 --- a/src/einsequence.jl +++ b/src/einsequence.jl @@ -267,6 +267,15 @@ function get_size_dict!(ne::NestedEinsum, @nospecialize(xs), size_info::Dict{LT} return get_size_dict_!(ixs, [collect(Int, size(xs[i])) for i in ks], size_info) end +function einsum!(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), y, sx, sy, size_dict::Dict) + # do not use map because the static overhead is too large + # do not use `setindex!` because we need to make the AD work + mxs = Vector{AbstractArray}(undef, length(siblings(neinsum))) + for (i, arg) in enumerate(siblings(neinsum)) + mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum!(arg, xs, similar(y, ([size_dict[l] for l in getiy(rootcode(arg))]...,)), true, false, size_dict)) + end + return einsum!(rootcode(neinsum), (mxs...,), y, sx, sy, size_dict) +end function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), size_dict::Dict) # do not use map because the static overhead is too large # do not use `setindex!` because we need to make the AD work diff --git a/src/einsum.jl b/src/einsum.jl new file mode 100644 index 0000000..b565ca2 --- /dev/null +++ b/src/einsum.jl @@ -0,0 +1,205 @@ +## non-inplace einsum +@doc raw" + einsum(code::EinCode, xs, size_dict) + einsum(rule, ixs, iy, xs, size_dict) + +return the tensor that results from contracting the tensors `xs` according +to their indices `ixs` (`getixs(code)`), where all indices that do not appear in the output `iy` (`getiy(code)`) are +summed over. +The result is permuted according to `out`. 
+ +- `ixs` - tuple of tuples of index-labels of the input-tensors `xs` + +- `iy` - tuple of index-labels of the output-tensor + +- `xs` - tuple of tensors + +- `size_dict` - a dictionary that maps index-labels to their sizes + +# example + +```jldoctest; setup = :(using OMEinsum) +julia> a, b = rand(2,2), rand(2,2); + +julia> einsum(EinCode((('i','j'),('j','k')),('i','k')), (a, b)) ≈ a * b +true + +julia> einsum(EinCode((('i','j'),('j','k')),('k','i')), (a, b)) ≈ permutedims(a * b, (2,1)) +true +``` +" +function einsum(code::AbstractEinsum, @nospecialize(xs::Tuple), size_dict::Dict = get_size_dict!(getixs(code), xs, Dict{labeltype(code),Int}())) + y = get_output_array(xs, map(y->size_dict[y],getiyv(code)); fillzero=false) + einsum!(code, xs, y, true, false, size_dict) +end + +# inplace einsum, EinCode as the input +function einsum!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(y), sx, sy, size_dict::Dict) + einsum!(getixs(code), getiy(code), xs, y, sx, sy, size_dict) +end +# inplace einsum, the fallback +function einsum!(ixs, iy, @nospecialize(xs::Tuple), @nospecialize(y), sx, sy, size_dict::Dict) + @debug "fallback to loop_einsum" ixs => iy size.(xs) + loop_einsum!(ixs, iy, (xs...,), y, sx, sy, size_dict) +end + +struct UnaryOperation{LT} + type + ix::Vector{LT} + iy::Vector{LT} +end +# for unary operations +# overhead ~ 2.3us +# @benchmark OMEinsum.einsum(DefaultRule(), $((('a', 'a', 'b'),)), $(('c', 'b','a')), (x,), $(Dict('a'=>1, 'b'=>1, 'c'=>1))) setup=(x=randn(1,1,1)) +function unary_pipeline(ix::Vector{LT}, iy::Vector{LT}) where LT + ix_unique = _unique(LT, ix) + iy_unique = _unique(LT, iy) + iy_a = filter(i->i ∈ ix, iy_unique) + + operations = UnaryOperation[] + if length(ix_unique) != length(ix) # diag + push!(operations, UnaryOperation(Diag(), ix, ix_unique)) + end + if length(ix_unique) != length(iy_a) # sum + push!(operations, UnaryOperation(Sum(), ix_unique, iy_a)) + elseif ix_unique != iy_a # permute, high freq + push!(operations, UnaryOperation(Permutedims(), ix_unique, iy_a)) + end + + if length(iy_a) != length(iy_unique) # repeat + push!(operations, UnaryOperation(Repeat(), iy_a, iy_unique)) + end + if length(iy_unique) != length(iy) # duplicate + push!(operations, UnaryOperation(Duplicate(), iy_unique, iy)) + end + return operations +end + +function einsum!(ixs, iy, @nospecialize(xs::NTuple{1, Any}), @nospecialize(y), sx, sy, size_dict::Dict{LT}) where LT + @debug "compiling unary" ixs[1] => iy size(xs[1]) + pipeline = unary_pipeline(collect(LT, ixs[1]), collect(LT, iy)) + lasttensor = xs[1] + for (k, op) in enumerate(pipeline) + if k == length(pipeline) # last operation + unary_einsum!(op.type, op.ix, op.iy, lasttensor, y, sx, sy) + else + cache = similar(y, ([size_dict[l] for l in op.iy]...,)) + unary_einsum!(op.type, op.ix, op.iy, lasttensor, cache, true, false) + lasttensor = cache + end + end + if length(pipeline) == 0 + @flatten_addmul! 
sy * y + sx * lasttensor + end + return y +end + +# there are too many combination in the binary case, so nospecialize +function einsum!(ixs, iy, @nospecialize(xs::NTuple{2, Any}), @nospecialize(y), sx, sy, size_dict::Dict{LT}) where LT + iyv = _collect(LT,iy) + ix1v, ix2v = _collect.(Ref(LT), ixs) + @debug "compiling binary" ixs => iyv size.(xs) + x1, x2 = xs + c1, c2, cy, s1, s2, s3, i1, i2, iyb = analyze_binary(ix1v, ix2v, iyv, size_dict) + rule = SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}() + xs1 = simplifyto(ix1v, c1, x1, size_dict) + xs2 = simplifyto(ix2v, c2, x2, size_dict) + x1_ = safe_reshape(xs1, s1) + x2_ = safe_reshape(xs2, s2) + @debug rule size.((x1_, x2_)) + if cy != iyv + y_ = similar(y, (s3...,)) + y_ = reshape(binary_einsum!(rule, x1_, x2_, y_, true, false), [size_dict[x] for x in cy]...) + return einsum!((cy,), iyv, (y_,), y, sx, sy, size_dict) + else + binary_einsum!(rule, x1_, x2_, safe_reshape(y, s3), sx, sy) + return y + end +end +safe_reshape(x, sz) = reshape(x, (sz...,)) + +function simplifyto(ix1, c1, x1, size_dict::Dict{LT}) where LT + if c1 != ix1 + xs1 = similar(x1, ([size_dict[l] for l in c1]...,)) + return einsum!((_collect(LT,ix1),), c1, (x1,), xs1, true, false, size_dict) + else + return x1 + end +end + +""" +Get the expected labels. +""" +function analyze_binary(ix1::Vector{T}, ix2::Vector{T}, iy::Vector{T}, size_dict::Dict{T,Int}) where T + ix_inner, ix1_outer, ix2_outer, batch = _analyze_binary_input(ix1, ix2, iy) + c1 = vcat(ix1_outer, ix_inner, batch) + c2 = vcat(ix_inner, ix2_outer, batch) + cy = vcat(ix1_outer, ix2_outer, batch) + si = prod(map(x->size_dict[x], ix1_outer)) + sj = prod(map(x->size_dict[x], ix_inner)) + sk = prod(map(x->size_dict[x], ix2_outer)) + sl = prod(map(x->size_dict[x], batch)) + has_i = !isempty(ix1_outer) + has_j = !isempty(ix_inner) + has_k = !isempty(ix2_outer) + has_l = !isempty(batch) + i1 = Char[] + i2 = Char[] + iyb = Char[] + s1 = Int[] + s2 = Int[] + s3 = Int[] + if has_i + push!(i1, 'i') + push!(iyb, 'i') + push!(s1, si) + push!(s3, si) + end + if has_j + push!(i1, 'j') + push!(i2, 'j') + push!(s1, sj) + push!(s2, sj) + end + if has_k + push!(i2, 'k') + push!(iyb, 'k') + push!(s2, sk) + push!(s3, sk) + end + if has_l + push!(i1, 'l') + push!(i2, 'l') + push!(iyb, 'l') + push!(s1, sl) + push!(s2, sl) + push!(s3, sl) + end + return c1, c2, cy, s1, s2, s3, i1, i2, iyb +end + +function _analyze_binary_input(ix1::Vector{T}, ix2::Vector{T}, iy::Vector{T}) where T + ix1_batch = T[] + ix1_inner = T[] + ix1_outer = T[] + for l1 in ix1 + if l1 ∈ ix2 + if l1 ∈ iy # batch + l1 ∉ ix1_batch && push!(ix1_batch, l1) + else # inner + l1 ∉ ix1_inner && push!(ix1_inner, l1) + end + elseif l1 ∈ iy # outer + l1 ∉ ix1_outer && push!(ix1_outer, l1) + else + # dangling + end + end + ix2_outer = T[] # outer dimension of x2 + for l2 in ix2 + if l2 ∉ ix1 && l2 ∈ iy && l2 ∉ ix2_outer + push!(ix2_outer, l2) + end + end + ix1_inner, ix1_outer, ix2_outer, ix1_batch +end \ No newline at end of file diff --git a/src/interfaces.jl b/src/interfaces.jl index 036d0a4..a91b24a 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -99,6 +99,9 @@ end LT = foldl((a, b) -> promote_type(a, eltype(b)), ixs; init=Union{}) return get_size_dict!(ixs, xs, size_info===nothing ? Dict{LT,Int}() : size_info) end +@inline function get_size_dict(ixs::AbstractVector{<:AbstractVector{LT}}, xs, size_info=nothing) where LT + return get_size_dict!(ixs, xs, size_info===nothing ? 
Dict{LT,Int}() : size_info) +end using MacroTools """ @@ -158,53 +161,4 @@ function _ein_macro(ex; einsum=:einsum) rightnames = [ esc(A) for (A, ind) in rightpairs ] return :( $(esc(Z)) = $einsum( EinCode(($(righttuples...),), $lefttuple), ($(rightnames...),)) ) -end - -@doc raw" - einsum(code::EinCode, xs, size_dict) - einsum(rule, ixs, iy, xs, size_dict) - -return the tensor that results from contracting the tensors `xs` according -to their indices `ixs` (`getixs(code)`), where all indices that do not appear in the output `iy` (`getiy(code)`) are -summed over. -The result is permuted according to `out`. - -- `ixs` - tuple of tuples of index-labels of the input-tensors `xs` - -- `iy` - tuple of index-labels of the output-tensor - -- `xs` - tuple of tensors - -- `size_dict` - a dictionary that maps index-labels to their sizes - -# example - -```jldoctest; setup = :(using OMEinsum) -julia> a, b = rand(2,2), rand(2,2); - -julia> einsum(EinCode((('i','j'),('j','k')),('i','k')), (a, b)) ≈ a * b -true - -julia> einsum(EinCode((('i','j'),('j','k')),('k','i')), (a, b)) ≈ permutedims(a * b, (2,1)) -true -``` -" -@generated function einsum(code::StaticEinCode{LT, ixs, iy}, xs::Tuple, size_dict::Dict{LT}) where {LT, ixs, iy} - rule = match_rule(ixs, iy) - :(einsum($rule, $ixs, $iy, xs, size_dict)) -end - -function einsum(code::DynamicEinCode, @nospecialize(xs::Tuple), size_dict::Dict) - rule = match_rule(getixs(code), getiy(code)) - einsum(rule, getixs(code), getiy(code), xs, size_dict) -end - -function einsum(code::EinCode, @nospecialize(xs::Tuple)) - einsum(code, xs, get_size_dict!(getixs(code), xs, Dict{labeltype(code),Int}())) -end - -# the fallback -function einsum(::DefaultRule, ixs, iy, xs::Tuple, size_dict) - @debug "DefaultRule loop_einsum" ixs => iy size.(xs) - loop_einsum(EinCode(ixs, iy), (xs...,), size_dict) -end +end \ No newline at end of file diff --git a/src/loop_einsum.jl b/src/loop_einsum.jl index 665eb59..4233f67 100644 --- a/src/loop_einsum.jl +++ b/src/loop_einsum.jl @@ -9,24 +9,29 @@ function loop_einsum(code::EinCode, xs::NTuple{N, AbstractArray{<:Any,M} where M size_dict) where {N} iy = getiy(code) size = getindex.(Ref(size_dict), iy) - loop_einsum!(code, xs, get_output_array(xs, size; has_repeated_indices=!allunique(iy)), size_dict) + loop_einsum!(getixs(code), getiy(code), xs, get_output_array(xs, size; fillzero=false), true, false, size_dict) end """ - loop_einsum!(::EinCode, xs, y, size_dict) + loop_einsum!(ixs, iy, xs, y, sx, sy, size_dict) inplace-version of `loop_einsum`, saving the result in a preallocated tensor of correct size `y`. 
""" -function loop_einsum!(code::EinCode, +function loop_einsum!(ixs, iy, xs::NTuple{N, AbstractArray{<:Any,M} where M}, - y::AbstractArray{T,L}, size_dict) where {N,L,T} - ALLOW_LOOPS[] || error("using `loop_einsum` is forbidden: code: $code") - A = einarray(Val((Tuple.(getixs(code))...,)), Val((getiy(code)...,)), xs, size_dict) - reduce_einarray!(A, y) + y::AbstractArray{T,L}, sx, sy, size_dict) where {N,L,T} + ALLOW_LOOPS[] || error("using `loop_einsum` is forbidden: code: $ixs -> $iy") + A = einarray(Val((Tuple.(ixs)...,)), Val((iy...,)), xs, size_dict) + if iszero(sy) + fill!(y, zero(T)) + elseif !isone(sy) + lmul!(sy, y) + end + reduce_einarray!(A, y, sx) end -function reduce_einarray!(A::EinArray{T}, y) where T +function reduce_einarray!(A::EinArray{T}, y, sx) where T @inbounds for ind_y in A.OCIS iy = subindex(A.y_indexer,ind_y) yi = zero(T) @@ -34,21 +39,21 @@ function reduce_einarray!(A::EinArray{T}, y) where T ind = TupleTools.vcat(ind_x.I,ind_y.I) yi += map_prod(A.xs, ind, A.x_indexers) end - y[iy] = yi + y[iy] += sx * yi end y end # speed up the get output array for the case when the inputs have the same type. -function get_output_array(xs::NTuple{N, AbstractArray{T,M} where M}, size; has_repeated_indices=true) where {T,N} - if has_repeated_indices +function get_output_array(xs::NTuple{N, AbstractArray{T,M} where M}, size; fillzero=true) where {T,N} + if fillzero zeros(T, size...) else Array{T}(undef, size...) end end -function get_output_array(xs::NTuple{N, AbstractArray{<:Any,M} where M}, size; has_repeated_indices=true) where N - if has_repeated_indices +function get_output_array(xs::NTuple{N, AbstractArray{<:Any,M} where M}, size; fillzero=true) where N + if fillzero zeros(promote_type(map(eltype,xs)...), size...) else Array{promote_type(map(eltype,xs)...)}(undef, size...) diff --git a/src/matchrule.jl b/src/matchrule.jl new file mode 100644 index 0000000..82f17ad --- /dev/null +++ b/src/matchrule.jl @@ -0,0 +1,149 @@ +@doc raw" + match_rule(ixs, iy) + match_rule(code::EinCode) + +Returns the rule that matches, otherwise use `DefaultRule` - the slow `loop_einsum` backend. +" +function match_rule(ixs, iy) + if length(ixs) == 1 + return match_rule_unary(ixs[1], iy) + elseif length(ixs) == 2 + return match_rule_binary(ixs[1], ixs[2], iy) + else + return DefaultRule() + end +end + +function match_rule_unary(ix, iy) + Nx = length(ix) + Ny = length(iy) + # the first rule with the higher the priority + if Ny == 0 && Nx == 2 && ix[1] == ix[2] + return Tr() + elseif allunique(iy) + if ix == iy + return Identity() + elseif allunique(ix) + if Nx == Ny + if all(i -> i in iy, ix) + return Permutedims() + else # e.g. (abcd->bcde) + return DefaultRule() + end + else + if all(i -> i in ix, iy) + return Sum() + elseif all(i -> i in iy, ix) # e.g. ij->ijk + return Repeat() + else # e.g. ijkxc,ijkl + return DefaultRule() + end + end + else # ix is not unique + if all(i -> i in ix, iy) && all(i -> i in iy, ix) # ijjj->ij + return Diag() + else + return DefaultRule() + end + end + else # iy is not unique + if allunique(ix) && all(x->x∈iy, ix) + if all(y->y∈ix, iy) # e.g. ij->ijjj + return Duplicate() + else # e.g. 
ij->ijjl + return DefaultRule() + end + else + return DefaultRule() + end + end +end + +match_rule(code::EinCode) = match_rule(getixs(code), getiy(code)) + + +@inline function _add_batch(::SimpleBinaryRule{ix1,ix2,iy}) where {ix1,ix2,iy} + SimpleBinaryRule{(ix1...,'l'), (ix2...,'l'), (iy...,'l')}() +end +@inline _add_batch(::DefaultRule) = DefaultRule() + +function match_rule_binary(ix1, ix2, iy) + Nx1, Nx2, Ny = length(ix1), length(ix2), length(iy) + if !_isunique(ix1) || !_isunique(ix2) || !_isunique(iy) + DefaultRule() + elseif (Nx1 + Nx2 + Ny) % 2 == 0 # no batch + _match_simple2(ix1,ix2,iy,Nx1,Nx2,Ny) + elseif Nx1>0 && Nx2>0 && Ny>0 && ix1[Nx1]==ix2[Nx2]==iy[Ny] + rule = _match_simple2(ix1,ix2,iy,Nx1-1,Nx2-1,Ny-1) + _add_batch(rule) + else + DefaultRule() + end +end +@inline function _isunique(ix) + if length(ix) <= 1 + return true + elseif length(ix) == 2 + return @inbounds ix[1] != ix[2] + elseif length(ix) == 3 + @inbounds a, b, c = ix + return a != c && a != c && a != b + else # to default rules + return false + end +end + +function _match_simple2(ix1, ix2, iy, Nx1, Nx2, Ny) + if Nx1==0 + if (Ny==Nx2==0) + return SimpleBinaryRule((), (), ()) + elseif (Ny==Nx2==1 && ix2[1] == iy[1]) + return SimpleBinaryRule((), ('k',), ('k',)) + end + elseif Nx1==1 + if (Nx2==0 && Ny==1 && iy[1]==ix1[1]) + return SimpleBinaryRule(('i',), (), ('i',)) + elseif (Nx2==1 && Ny==0 && ix1[1]==ix2[1]) + return SimpleBinaryRule(('j',), ('j',), ()) + elseif Nx2==1 && Ny==2 + if (iy[1]==ix1[1] && iy[2]==ix2[1]) + return SimpleBinaryRule(('i',), ('k',), ('i','k')) + elseif iy[1]==ix2[1] && iy[2]==ix1[1] + return SimpleBinaryRule(('i',), ('k',), ('k','i')) + end + elseif Nx2==2 && Ny==1 + if ix2[1]==ix1[1] && ix2[2]==iy[1] + return SimpleBinaryRule(('j',), ('j','k'), ('k',)) + elseif ix2[1]==iy[1] && ix2[2]==ix1[1] + return SimpleBinaryRule(('j',), ('k','j'), ('k',)) + end + end + elseif Nx1==2 + if Nx2==1 && Ny==1 + if ix1[1]==ix2[1] && ix1[2]==iy[1] + return SimpleBinaryRule(('j','i'), ('j',), ('i',)) + elseif ix1[1]==iy[1] && ix1[2]==ix2[1] + return SimpleBinaryRule(('i','j'), ('j',), ('i',)) + end + elseif (Nx2==2 && Ny==2) + if ix1[1]==ix2[1] && ix1[2]==iy[1] && ix2[2]==iy[2] + return SimpleBinaryRule(('j','i'), ('j','k'), ('i','k')) + elseif ix1[1]==ix2[2] && ix1[2]==iy[1] && ix2[1]==iy[2] + return SimpleBinaryRule(('j','i'), ('k','j'), ('i','k')) + elseif ix1[1]==ix2[1] && ix1[2]==iy[2] && ix2[2]==iy[1] + return SimpleBinaryRule(('j','i'), ('j','k'), ('k','i')) + elseif ix1[1]==ix2[2] && ix1[2]==iy[2] && ix2[1]==iy[1] + return SimpleBinaryRule(('j','i'), ('k','j'), ('k','i')) + elseif ix1[2]==ix2[1] && ix1[1]==iy[1] && ix2[2]==iy[2] + return SimpleBinaryRule(('i','j'), ('j','k'), ('i','k')) + elseif ix1[2]==ix2[2] && ix1[1]==iy[1] && ix2[1]==iy[2] + return SimpleBinaryRule(('i','j'), ('k','j'), ('i','k')) + elseif ix1[2]==ix2[1] && ix1[1]==iy[2] && ix2[2]==iy[1] + return SimpleBinaryRule(('i','j'), ('j','k'), ('k','i')) + elseif ix1[2]==ix2[2] && ix1[1]==iy[2] && ix2[1]==iy[1] + return SimpleBinaryRule(('i','j'), ('k','j'), ('k','i')) + end + end + end + return DefaultRule() +end \ No newline at end of file diff --git a/src/unaryrules.jl b/src/unaryrules.jl index 84693bf..95f2b2e 100644 --- a/src/unaryrules.jl +++ b/src/unaryrules.jl @@ -16,142 +16,79 @@ struct Duplicate <: EinRule{1} end struct Diag <: EinRule{1} end struct DefaultRule <: EinRule{Any} end -@doc raw" - match_rule(ixs, iy) - match_rule(code::EinCode) - -Returns the rule that matches, otherwise use `DefaultRule` - the slow `loop_einsum` 
backend. -" -function match_rule(ixs, iy) - if length(ixs) == 1 - return match_rule_unary(ixs[1], iy) - elseif length(ixs) == 2 - return match_rule_binary(ixs[1], ixs[2], iy) - else - return DefaultRule() - end -end - -function match_rule_unary(ix, iy) - Nx = length(ix) - Ny = length(iy) - # the first rule with the higher the priority - if Ny == 0 && Nx == 2 && ix[1] == ix[2] - return Tr() - elseif allunique(iy) - if ix == iy - return Identity() - elseif allunique(ix) - if Nx == Ny - if all(i -> i in iy, ix) - return Permutedims() - else # e.g. (abcd->bcde) - return DefaultRule() - end - else - if all(i -> i in ix, iy) - return Sum() - elseif all(i -> i in iy, ix) # e.g. ij->ijk - return Repeat() - else # e.g. ijkxc,ijkl - return DefaultRule() - end - end - else # ix is not unique - if all(i -> i in ix, iy) && all(i -> i in iy, ix) # ijjj->ij - return Diag() - else - return DefaultRule() - end - end - else # iy is not unique - if allunique(ix) && all(x->x∈iy, ix) - if all(y->y∈ix, iy) # e.g. ij->ijjj - return Duplicate() - else # e.g. ij->ijjl - return DefaultRule() - end - else - return DefaultRule() - end - end -end - -match_rule(code::EinCode) = match_rule(getixs(code), getiy(code)) - # trace # overhead ~ 0.07us # @benchmark OMEinsum.einsum(Tr(), $(('a', 'a')), $(()), x, $(Dict('a'=>1, 'b'=>1))) setup=(x=randn(1,1)) -function einsum(::Tr, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict::Dict) - x = xs[1] +function unary_einsum!(::Tr, ix, iy, x, y::AbstractArray, sx, sy) @debug "Tr" size(x) - asarray(tr(x), x) + y .= sy .* y .+ sx * tr(x) + return y end # overhead ~ 0.55us # @benchmark OMEinsum.einsum(Sum(), $(('a', 'b')), $(('b',)), x, $(Dict('a'=>1, 'b'=>1))) setup=(x=randn(1,1)) -function einsum(::Sum, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict::Dict{LT}) where LT - ix, x = ixs[1], xs[1] +function unary_einsum!(::Sum, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy) @debug "Sum" ix => iy size(x) dims = (findall(i -> i ∉ iy, ix)...,)::NTuple{length(ix)-length(iy),Int} res = dropdims(sum(x, dims=dims), dims=dims) ix1f = filter(i -> i ∈ iy, ix)::typeof(iy) if ix1f != iy - return einsum(Permutedims(), ((ix1f...,),), iy, (res,), size_dict) + return unary_einsum!(Permutedims(), (ix1f...,), iy, res, y, sx, sy) else - return res + return @flatten_addmul! sy * y + sx * res end end # overhead ~ 0.53us # @benchmark OMEinsum.einsum(OMEinsum.Repeat(), $(('a',)), $(('a', 'b',)), x, $(Dict('a'=>1, 'b'=>1))) setup=(x=randn(1)) -function einsum(::Repeat, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict::Dict) - ix, x = ixs[1], xs[1] +function unary_einsum!(::Repeat, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy) @debug "Repeat" ix => iy size(x) ix1f = filter(i -> i ∈ ix, iy) - res = if ix1f != ix - einsum(Permutedims(), (ix,), ix1f, (x,), size_dict) + shape1 = [s for (l, s) in zip(iy, size(y)) if l ∈ ix] + shape2 = [l ∈ ix ? s : 1 for (l, s) in zip(iy, size(y))] + repeat_dims = [l ∈ ix ? 1 : s for (l, s) in zip(iy, size(y))] + # TODO: avoid copy + if ix1f != ix + y1 = similar(x, (shape1...,)) + unary_einsum!(Permutedims(), ix, ix1f, x, y1, true, false) else - x + y1 = x end - newshape = [l ∈ ix ? size_dict[l] : 1 for l in iy] - repeat_dims = [l ∈ ix ? 1 : size_dict[l] for l in iy] - repeat(reshape(res, newshape...), repeat_dims...) + @flatten_addmul! sy * y + sx * repeat(reshape(y1, shape2...), repeat_dims...) 
end # overhead ~ 0.28us # @benchmark OMEinsum.einsum(Diag(), $(('a', 'a')), $(('a',)), x, $(Dict('a'=>1, 'b'=>1))) setup=(x=randn(1,1)) -function einsum(::Diag, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict::Dict) - ix, x = ixs[1], xs[1] +function unary_einsum!(::Diag, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy) @debug "Diag" ix => iy size.(x) - compactify!(get_output_array((x,), map(y->size_dict[y],iy); has_repeated_indices=false),x,ix, iy) + compactify!(y, x, ix, iy, sx, sy) end -function compactify!(y, x, ix, iy) +function compactify!(y, x, ix, iy, sx, sy) x_in_y_locs = (Int[findfirst(==(x), iy) for x in ix]...,) @assert size(x) == map(loc->size(y, loc), x_in_y_locs) indexer = dynamic_indexer(x_in_y_locs, size(x)) - _compactify!(y, x, indexer) + _compactify!(y, x, indexer, sx, sy) end -function _compactify!(y, x, indexer) +function _compactify!(y, x, indexer, sx, sy) @inbounds for ci in CartesianIndices(y) - y[ci] = x[subindex(indexer, ci.I)] + y[ci] = sy * y[ci] + sx * x[subindex(indexer, ci.I)] end return y end -function duplicate(x, ix, iy, size_dict) - y = get_output_array((x,), map(y->size_dict[y],iy); has_repeated_indices=true) +function duplicate!(y, x, ix, iy, sx, sy) # compute same locs x_in_y_locs = (Int[findfirst(==(l), ix) for l in iy]...,) indexer = dynamic_indexer(x_in_y_locs, size(y)) - _duplicate!(y, x, indexer) + lmul!(sy, y) + _duplicate!(y, x, indexer, sx) end -@noinline function _duplicate!(y, x, indexer) +@noinline function _duplicate!(y, x, indexer, sx) map(CartesianIndices(x)) do ci - @inbounds y[subindex(indexer, ci.I)] = x[ci] + @inbounds y[subindex(indexer, ci.I)] += sx * x[ci] end return y end @@ -159,49 +96,22 @@ end # e.g. 'ij'->'iij', left indices are unique, right are not # overhead ~ 0.29us # @benchmark OMEinsum.einsum(Duplicate(), $((('a', ),)), $(('a','a')), (x,), $(Dict('a'=>1, 'b'=>1))) setup=(x=randn(1)) -function einsum(::Duplicate, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict) - ix, x = ixs[1], xs[1] +function unary_einsum!(::Duplicate, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy) @debug "Duplicate" ix => iy size(x) - duplicate(x, ix, iy, size_dict) + duplicate!(y, x, ix, iy, sx, sy) end # overhead ~ 0.15us # @benchmark OMEinsum.einsum(Permutedims(), $((('a', 'b'),)), $(('b','a')), (x,), $(Dict('a'=>1, 'b'=>1))) setup=(x=randn(1,1)) -function einsum(::Permutedims, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict) - ix, x = ixs[1], xs[1] +function unary_einsum!(::Permutedims, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy) perm = ntuple(i -> findfirst(==(iy[i]), ix)::Int, length(iy)) @debug "Permutedims" ix => iy size(x) perm - return tensorpermute(x, perm) + return tensorpermute!(y, x, perm, sx, sy) end # overhead ~0.04us # @benchmark OMEinsum.einsum(Identity(), $((('a', 'b'),)), $(('a','b')), (x,), $(Dict('a'=>1, 'b'=>1))) setup=(x=randn(1,1)) -function einsum(::Identity, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict) - @debug "Identity" ixs[1] => iy size(xs[1]) - copy(xs[1]) # must copy, otherwise AD may fail! -end - -# for unary operations -# overhead ~ 2.3us -# @benchmark OMEinsum.einsum(DefaultRule(), $((('a', 'a', 'b'),)), $(('c', 'b','a')), (x,), $(Dict('a'=>1, 'b'=>1, 'c'=>1))) setup=(x=randn(1,1,1)) -function einsum(::DefaultRule, ixs, iy, xs::Tuple{<:AbstractArray}, size_dict::Dict{LT}) where LT - ix, x = ixs[1], xs[1] - @debug "DefaultRule unary" ix => iy size(x) - # diag - ix_ = _unique(LT, ix) - x_ = length(ix_) != length(ix) ? 
einsum(Diag(), (ix,), (ix_...,), (x,), size_dict) : x - # sum - iy_b = _unique(LT, iy) - iy_a = filter(i->i ∈ ix, iy_b) - y_a = if length(ix_) != length(iy_a) - einsum(Sum(), ((ix_...,),), (iy_a...,), (x_,), size_dict) - elseif ix_ != iy_a - einsum(Permutedims(), ((ix_...,),), (iy_a...,), (x_,), size_dict) - else - x_ - end - # repeat - y_b = length(iy_a) != length(iy_b) ? einsum(Repeat(), ((iy_a...,),), (iy_b...,), (y_a,), size_dict) : y_a - # duplicate - length(iy_b) != length(iy) ? einsum(Duplicate(), ((iy_b...,),), iy, (y_b,), size_dict) : y_b -end +function unary_einsum!(::Identity, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy) + @debug "Identity" ix => iy size(x) + @flatten_addmul! sy * y + sx * x # NOTE: copy can not be avoided, otherwise AD may fail! +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 1031c63..e437816 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,41 @@ +macro addmul!(ex) + esc(addmul_impl(ex, false)) +end +macro flatten_addmul!(ex) + esc(addmul_impl(ex, true)) +end + +function addmul_impl(ex::Expr, flatten::Bool) + @assert ex.head === :call && length(ex.args) == 3 + dotadd, ay, bxs = ex.args + @assert dotadd == :+ + @assert ay.head === :call && length(ay.args) == 3 + dotmul, a, y = ay.args + @assert dotmul == :* + @assert bxs.head === :call + dotmul2, b, xs... = bxs.args + @assert dotmul2 == :* + @assert length(xs) > 0 + + added = :(Ref($b)) + for x in xs + added = :($added .* $(flatten ? :(vec($x)) : x)) + end + vy = flatten ? :(vec($y)) : y + quote + if iszero($b) # no need to multiply + $lmul!($a, $vy) + elseif iszero($a) # empty y + $vy .= $added + elseif isone($a) + $vy .+= $added + else # a != 1, a != 0, b != 0 + $vy .= Ref($a) .* $vy .+ $added + end + $y + end +end + """ asarray(x[, parent::AbstractArray]) -> AbstactArray @@ -81,11 +119,10 @@ end `permutedims(A, perm)` with grouped dimensions. """ -function tensorpermute(A::AbstractArray{T,N}, perm) where {T, N} +function tensorpermute!(C::AbstractArray{T, N}, A::AbstractArray{T,N}, perm, sx, sy) where {T, N} @assert N == length(perm) && all(p->1<=p<=N, perm) N == 0 && return copy(A) # group `perm`s - permshape = ntuple(i->size(A, @inbounds perm[i]), N) newshape_slots = fill(-1, N) dk = 1 # the size of dimension-batch @inbounds begin @@ -108,29 +145,54 @@ function tensorpermute(A::AbstractArray{T,N}, perm) where {T, N} newshape = filter(!=(-1), newshape_slots) newperm = sortperm(sortperm(newperm)) A_ = reshape(A, newshape...) - A__ = permutedims(A_, newperm) - return reshape(A__, permshape...) -end - -# reload this function for GPU support! -function _batched_gemm(C1::Char, C2::Char, A::StridedArray{T,3}, B::StridedArray{T2,3}) where {T<:BlasFloat, T2<:BlasFloat} - batched_gemm(C1, C2, A, B) + permed_shape = ntuple(i->size(A_, @inbounds newperm[i]), ndims(A_)) + if iszero(sy) + permutedims!(reshape(C, permed_shape), A_, newperm) + !isone(sx) && lmul!(sx, C) + return C + else + return @flatten_addmul! sy * C + sx * permutedims(A_, newperm) + end end -function _batched_gemm(C1::Char, C2::Char, A::AbstractArray{T,3}, B::AbstractArray{T2,3}) where {T<:BlasFloat, T2<:BlasFloat} - batched_gemm(C1, C2, Array(A), Array(B)) +# new interface for GPU support! 
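Since `@addmul!`/`@flatten_addmul!` are new in this patch and carry most of the `sx`/`sy` bookkeeping, here is a small sketch of the intended behaviour based on `addmul_impl` above; the macros are internal, hence the module-qualified call, and this is an illustration rather than package documentation.

```julia
using OMEinsum

y, x1, x2 = randn(3), randn(3), randn(3)
y0 = copy(y)
# `@addmul! a * y + b * x1 * x2` updates `y` in place as `y .= a .* y .+ b .* x1 .* x2`,
# with the cases `a == 0`, `a == 1` and `b == 0` short-circuited.
OMEinsum.@addmul! 2.0 * y + 3.0 * x1 * x2
y ≈ 2.0 .* y0 .+ 3.0 .* x1 .* x2   # expected to hold
```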
+# function _batched_gemm!(C1::Char, C2::Char, alpha, A::StridedArray{T,3}, B::StridedArray{T2,3}, beta, C::StridedArray{T3,3}) where {T<:BlasFloat, T2<:BlasFloat, T3<:BlasFloat} +# batched_gemm!(C1, C2, alpha, A, B, beta, C) +# end +function _batched_gemm!(C1::Char, C2::Char, alpha, A::AbstractArray{T,3}, B::AbstractArray{T2,3}, beta, C::AbstractArray{T3,3}) where {T<:BlasFloat, T2<:BlasFloat,T3<:BlasFloat} + # NOTE: convert alpha and beta to T3, since booleans are not supported by BatchedRoutines + #batched_gemm!(C1, C2, T3(alpha), Array(A), Array(B), T3(beta), C) + batched_gemm!(C1, C2, T3(alpha), A, B, T3(beta), C) end - -function _batched_gemm(C1::Char, C2::Char, A::AbstractArray{T,3}, B::AbstractArray{T2,3}) where {T, T2} - @assert size(A, 3) == size(B, 3) "batch dimension mismatch, got $(size(A,3)) and $(size(B,3))" +function _batched_gemm!(C1::Char, C2::Char, alpha, A::AbstractArray{T,3}, B::AbstractArray{T2,3}, beta, C::AbstractArray{T3,3}) where {T, T2,T3} + @assert size(A, 3) == size(B, 3) == size(C, 3) "batch dimension mismatch, got $(size(A,3)), $(size(B,3)) and $(size(C,3))" @assert C1 === 'N' || C1 === 'T' @assert C2 === 'N' || C2 === 'T' - L = size(A, 3) - C = similar(A, promote_type(T,T2), C1==='N' ? size(A,1) : size(A,2), C2==='N' ? size(B,2) : size(B,1), L) - for l = 1:L + for l = 1:size(A, 3) a = C1 === 'T' ? transpose(view(A,:,:,l)) : view(A,:,:,l) b = C2 === 'T' ? transpose(view(B,:,:,l)) : view(B,:,:,l) - mul!(view(C,:,:,l), a, b) + mul!(view(C,:,:,l), a, b, alpha, beta) end return C end + +# macro addmul!(a, y, b, xs...) +# added = :(Ref(b)) +# for x in xs +# added = :($added .* $x) +# end +# yeval = gensym("y") +# quote +# $yeval = $y +# if iszero($b) # no need to multiply +# $lmul!($a, $yeval) +# elseif iszero($a) # empty y +# $yeval .= $added +# elseif isone($a) +# $yeval .+= $added +# else # a != 1, a != 0, b != 0 +# $yeval .= Ref($a) .* $yeval .+ $added +# end +# $yeval +# end |> esc +# end \ No newline at end of file diff --git a/test/binaryrules.jl b/test/binaryrules.jl index 70fa4c6..f2c6cb3 100644 --- a/test/binaryrules.jl +++ b/test/binaryrules.jl @@ -1,10 +1,10 @@ using OMEinsum, Test -using OMEinsum: SimpleBinaryRule, match_rule +using OMEinsum: SimpleBinaryRule, match_rule, binary_einsum! using Polynomials: Polynomial @testset "analyse binary" begin size_dict = Dict(1=>1, 2=>2, 3=>3, 4=>4, 6=>6, 7=>7, 8=>8) - c1, c2, cy, s1, s2, is, js, ys = OMEinsum.analyze_binary([1,2,3,4,8], [2,6,6,8,4,2], [7,2,1,2,2,6], size_dict) + c1, c2, cy, s1, s2, s3, is, js, ys = OMEinsum.analyze_binary([1,2,3,4,8], [2,6,6,8,4,2], [7,2,1,2,2,6], size_dict) @test c1 == [1,4,8,2] @test c2 == [4,8,6,2] @test cy == [1,6,2] @@ -32,7 +32,8 @@ end rule = match_rule(code) if rule isa SimpleBinaryRule nmatch += 1 - @test einsum(rule, (a, b)) ≈ loop_einsum(code, (a, b), size_dict) + out = OMEinsum.get_output_array((a, b), [size_dict[l] for l in getiyv(code)]; fillzero=false) + @test binary_einsum!(rule, a, b, out, true, false) ≈ loop_einsum(code, (a, b), size_dict) else @test einsum(code, (a, b), size_dict) ≈ loop_einsum(code, (a, b), size_dict) end diff --git a/test/cueinsum.jl b/test/cueinsum.jl index dac134f..9b6d917 100644 --- a/test/cueinsum.jl +++ b/test/cueinsum.jl @@ -6,12 +6,27 @@ using Zygote CUDA.allowscalar(false) +@testset "loop einsum" begin + a = [randn(fill(3, i)...) 
for i=1:4] + ca = a .|> CuArray + ixs = ((1,2), (2,3)) + xs = (ca[2], ca[2]) + @test loop_einsum!(ixs, (1,3), xs, zeros(3,3) |> CuArray, true, false, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ a[2]*a[2] + @test loop_einsum!(ixs, (1,3), xs, ones(3,3) |> CuArray, 4.0, 2.0, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ 4.0 * a[2]*a[2] + 2 * ones(3, 3) + res = 4.0 * a[2]*a[2] + out = 2 * ones(3, 3, 3) + for k = 1:3 + out[:,k,k] .+= res[:,k] + end + @test loop_einsum!(ixs, (1,3,3), xs, ones(3,3,3) |> CuArray, 4.0, 2.0, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ out +end + @testset "cuda einsum" begin a = [randn(fill(3, i)...) for i=1:4] ca = a .|> CuArray ixs = ((1,2), (2,3)) xs = (ca[2], ca[2]) - @test loop_einsum!(EinCode(ixs, (1,3)), xs, zeros(3,3) |> CuArray, OMEinsum.get_size_dict(ixs, xs)) ≈ ca[2]*ca[2] + @test loop_einsum!(ixs, (1,3), xs, zeros(3,3) |> CuArray, true, false, OMEinsum.get_size_dict(ixs, xs)) ≈ ca[2]*ca[2] for f in [ein"ij,jk->ik", ein"ii->", ein"ijj ->i", ein"ij,ik,il->jkl", ein"ii->i", ein"ijl->i", ein"i->ii", ein"ij,jk,kl->il", ein"ij,ij,ij -> ij"] cins = map(ix->ca[length(ix)], OMEinsum.getixs(f)) ins = map(ix->a[length(ix)], OMEinsum.getixs(f)) @@ -56,7 +71,7 @@ end @test M |> Array ≈ loop_einsum(_code, xs, OMEinsum.get_size_dict(OMEinsum.getixs(_code), xs)) |> Array end -@testset "binary einsum" begin +@testset "unary einsum rules" begin for code in [ ein"ij->", # sum ein"ij->j", # sum @@ -69,10 +84,58 @@ end ein"i->ik", # ~sum ein"->ik", # ~sum ein"illljkk->kijjcc", # general + ] + @info code + D = 2 + xs = [length(ix)==0 ? CUDA.fill(1.2) : CUDA.rand(Float64, fill(D, length(ix))...) for ix in OMEinsum.getixs(code)] + size_dict = Dict(zip(('a', 'b', 'c', 'd', 'e', 'f','i','j','k','l'), ntuple(x->D, 10))) + res = einsum(code, (xs...,), size_dict) + @test Array(res) ≈ loop_einsum(code, (Array.(xs)...,), size_dict) + @test Array(res) ≈ Array(loop_einsum(code, (xs...,), size_dict)) + end +end + +@testset "binary einsum rules" begin + codes = [ # binary ein",->", - ein"ijl,jl->il", - ein"ab,bc->ac", + ein"i,->i", + ein"j,j->", + ein",k->k", + ein"j,jk->k", + ein"j,kj->k", + ein"ij,j->i", + ein"ji,j->i", + ein"i,k->ik", + ein"i,k->ki", + ] + + for (i1, X1) in enumerate([('i', 'j'), ('j', 'i')]) + for (i2, X2) in enumerate([('j', 'k'), ('k', 'j')]) + for (i3, X3) in enumerate([('i', 'k'), ('k', 'i')]) + push!(codes, OMEinsum.StaticEinCode{Char, (X1,X2),X3}()) + end + end + end + for code in copy(codes) + X1, X2 = OMEinsum.getixs(code) + X3 = OMEinsum.getiy(code) + push!(codes, OMEinsum.StaticEinCode{Char, ((X1...,'l'),(X2...,'l')),(X3...,'l')}()) + end + + for code in codes + @info code + D = 2 + xs = [length(ix)==0 ? CUDA.fill(1.2) : CUDA.rand(Float64, fill(D, length(ix))...) 
for ix in OMEinsum.getixs(code)] + size_dict = Dict(zip(('a', 'b', 'c', 'd', 'e', 'f','i','j','k','l'), ntuple(x->D, 10))) + res = einsum(code, (xs...,), size_dict) + @test Array(res) ≈ loop_einsum(code, (Array.(xs)...,), size_dict) + @test Array(res) ≈ Array(loop_einsum(code, (xs...,), size_dict)) + end +end + +@testset "composite einsum" begin + for code in [ ein"abb,bc->ac", # with diag in ein"ab,bc->acc", # with diag out ein"ab,bce->ac", # with sum in @@ -92,20 +155,6 @@ end end end -@testset "binary rules" begin - for (code, a, b) in [ - (ein"j,j->", randn(10), randn(10)), - (ein"i,->i", randn(10), fill(2.0, ())), - (ein",->", fill(2.0,()), fill(2.0, ())), - (ein"il,kl->ikl", randn(10, 10), randn(10, 10)), - ] - res0 = code(a, b) - res1 = code(CuArray(a), CuArray(b)) - @test res1 isa CuArray - @test res0 ≈ Array(res1) - end -end - @testset "permutedims for high dimensional tensors" begin c = CUDA.rand(4, [rand(1:3) for _ = 2:18]...); @test Array(permutedims(c, 18:-1:1)) ≈ permutedims(Array(c), 18:-1:1) diff --git a/test/einsequence.jl b/test/einsequence.jl index a749aae..0560753 100644 --- a/test/einsequence.jl +++ b/test/einsequence.jl @@ -30,10 +30,20 @@ using OMEinsum: IndexGroup, NestedEinsum, parse_nested, DynamicEinCode, isleaf, abc2 = ein"((ij,jk),km) -> im"(a,b,c) abc3 = ein"ij,jk,km -> im"(a,b,c) + ne = ein"(ij,jk),km -> im" + print(ne) + dne = DynamicNestedEinsum(ne) + print(dne) + args = (dne, dne) + eins = DynamicEinCode([['a', 'b'], ['b', 'c']], ['a', 'c']) + @test DynamicNestedEinsum(args, eins) isa DynamicNestedEinsum + @test abc1 ≈ abc2 ≈ abc3 size_info = Dict('k'=>2) a, b, c, d = randn(2), randn(2,2), randn(2), randn(2) @test ein"((i,ij),i),j->ik"(a, b, c, d; size_info=size_info) ≈ ein"i,ij,i,j->ik"(a, b, c, d; size_info=size_info) + size_dict = Dict([l=>2 for l in "ijkl"]) + @test einsum!(ein"((i,ij),i),j->ik", (a, b, c, d), randn(2, 2), true, false, size_dict) ≈ ein"i,ij,i,j->ik"(a, b, c, d; size_info=size_info) @test getixsv(ein"((i,ij),i),j->ik") == getixsv(ein"i,ij,i,j->ik") == getixsv(DynamicEinCode(ein"i,ij,i,j->ik")) == [['i'], ['i','j'], ['i'], ['j']] @test getiyv(ein"((i,ij),i),j->ik") == getiyv(ein"i,ij,i,j->ik") == getiyv(DynamicEinCode(ein"i,ij,i,j->ik")) == ['i','k'] end diff --git a/test/einsum.jl b/test/einsum.jl index c365808..30d33fc 100644 --- a/test/einsum.jl +++ b/test/einsum.jl @@ -1,8 +1,8 @@ using Test using OMEinsum -using OMEinsum: get_size_dict, Sum, Tr, DefaultRule, Permutedims, Duplicate +using OMEinsum: get_size_dict using SymEngine -using LinearAlgebra: I +using LinearAlgebra: I, tr SymEngine.free_symbols(syms::Union{Real, Complex}) = Basic[] SymEngine.free_symbols(syms::AbstractArray{T}) where {T<:Union{Real, Complex}} = Basic[] @@ -24,6 +24,58 @@ Base.:≈(x::AbstractArray, y::AbstractArray{<:Basic}; atol=1e-8) = _basic_appro Base.:≈(x::AbstractArray{<:Basic}, y::AbstractArray{<:Basic}; atol=1e-8) = _basic_approx(x, y, atol=atol) Base.Complex{T}(a::Basic) where T = T(real(a)) + im*T(imag(a)) +@testset "unary einsum" begin + size_dict = Dict(1=>3,2=>3,3=>3,4=>4,5=>5) + ix = (1,2,3,3,4) + x = randn(3,3,3,3,4) + iy = (3,5,1,1,2,5) + y = randn(3,5,3,3,3,5) + # Diag, Sum, Repeat, Duplicate + @test einsum!((ix,), iy, (x,), y, true, false, size_dict) ≈ loop_einsum(EinCode((ix,), iy), (x,), size_dict) + ix = (1,2,3,4) + x = randn(3,3,3,4) + iy = (4,3,1,2) + y = randn(4,3,3,3) + # Permutedims + @test einsum!((ix,), iy, (x,), y, true, false, size_dict) ≈ loop_einsum(EinCode((ix,), iy), (x,), size_dict) + # None + ix = (1,2,3,4) + x = 
randn(3,3,3,4) + iy = (1,2,3,4) + y = randn(3,3,3,4) + @test einsum!((ix,), iy, (x,), y, true, false, size_dict) ≈ loop_einsum(EinCode((ix,), iy), (x,), size_dict) + # tr + ix = (1,1) + x = randn(3,3) + iy = () + y = fill(1.0) + @test einsum!((ix,), iy, (x,), y, true, false, size_dict)[] ≈ tr(x) +end + +@testset "binary einsum" begin + size_dict = Dict(1=>3,2=>3,3=>3,4=>4,5=>5) + ix = (1,2,3,3,4) + x = randn(3,3,3,3,4) + iy = (3,5,1,1,2,5) + y = randn(3,5,3,3,3,5) + iz = (1,2,3,4,5,5) + z = randn(3,3,3,4,5,5) + @test einsum!((ix, iy), iz, (x, y), z, true, false, size_dict) ≈ loop_einsum(EinCode((ix, iy), iz), (x, y), size_dict) + @test einsum!((ix, iy), iz, (x, y), copy(z), 5.0, 3.0, size_dict) ≈ loop_einsum!((ix, iy), iz, (x, y), copy(z), 5.0, 3.0, size_dict) + @test einsum!((ix, iy), iz, (x, y), copy(z), 5.0, 1.0, size_dict) ≈ loop_einsum!((ix, iy), iz, (x, y), copy(z), 5.0, 1.0, size_dict) +end + +@testset "nary, einsum" begin + size_dict = Dict(1=>3,2=>3,3=>3,4=>4,5=>5) + ix = (1,2,3,3,4) + x = randn(3,3,3,3,4) + iy = (3,5,1) + y = randn(3,5,3) + iz = (1,2,3,4,5,5) + z = randn(3,3,3,4,5,5) + @test einsum!((ix, iy, iz), (), (x, y, z), fill(1.0), true, false, size_dict) ≈ loop_einsum(EinCode((ix, iy, iz), ()), (x, y, z), size_dict) +end + @testset "get output array" begin xs = (randn(4,4), randn(3)) @test OMEinsum.get_output_array(xs, (5, 5)) isa Array{Float64} @@ -207,39 +259,6 @@ end @test_throws DimensionMismatch einsum(ein"ij,jk -> ik", (rand(2,3), rand(2,2))) end -@testset "dispatched" begin - # index-sum - a = rand(2,2,5) - ixs, xs = ((1,2,3),), (a,) - @test einsum(Sum(), ixs,(1,2),xs, get_size_dict(ixs, xs)) ≈ sum(a, dims=3) - a = rand(5,5) - @test einsum(Tr(), ((1,1),),(), (a,), get_size_dict(((1,1),), (a,)))[] ≈ sum(a[i,i] for i in 1:5) - t = rand(5,5,5,5) - a = rand(5,5) - size_dict = Dict(zip((1,2,3,4,2,3), ((size(t)..., size(a)...)))) - - OMEinsum.allow_loops(false) - @test_throws ErrorException loop_einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) - OMEinsum.allow_loops(true) - - ta = loop_einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) - @test einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) ≈ ta - @test einsum(DefaultRule(), ((1,2,3,4), (2,3)), (1,4), (t,a), size_dict) ≈ ta - - # index-sum - a = Basic.(rand(2,2,5)) - ixs, xs = ((1,2,3),), (a,) - @test einsum(Sum(), ixs,(1,2),xs, get_size_dict(ixs, xs)) ≈ sum(a, dims=3) - a = Basic.(rand(5,5)) - @test isapprox(einsum(Tr(), ((1,1),),(), (a,), get_size_dict(((1,1),), (a,)))[], sum(a[i,i] for i in 1:5), rtol=1e-8) - t = Basic.(rand(5,5,5,5)) - a = Basic.(rand(5,5)) - size_dict = Dict(zip((1,2,3,4,2,3), ((size(t)..., size(a)...)))) - ta = loop_einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) - @test einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) ≈ ta - @test einsum(DefaultRule(), ((1,2,3,4), (2,3)), (1,4), (t,a), size_dict) ≈ ta -end - @testset "isbatchmul" begin for (ixs, iy) in [(((1,2), (2,3)), (1,3)), (((1,2,3), (2,3)), (1,3)), (((7,1,2,3), (2,4,3,7)), (1,4,3)), @@ -250,31 +269,12 @@ end end end -@testset "duplicate" begin - ix = (1,2,3) - iy = (3,2,1,1,2) - size_dict = Dict(1=>3,2=>3,3=>3) - x = randn(3,3,3) - @test OMEinsum.duplicate(x, ix, iy, size_dict) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) - @test OMEinsum.einsum(Duplicate(), (ix,), iy, (x,), size_dict) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) -end - @testset "issue 136" begin @test EinCode(((1,2,3),(2,)),(1,3))(ones(2,2,1), ones(2)) == reshape([2,2.0], 2, 1) @test 
EinCode(((1,2,3),(2,)),(1,3))(ones(2,2,0), ones(2)) == reshape(zeros(0), 2, 0) end @testset "fix rule cc,cb->bc" begin - @test OMEinsum.match_rule_binary([3], [1], [1,3]) isa OMEinsum.SimpleBinaryRule - @test OMEinsum.match_rule_binary([1,3], [2,3], [1,2,3]) isa OMEinsum.SimpleBinaryRule - @test OMEinsum.match_rule_binary([3], [3], [3,3]) isa OMEinsum.DefaultRule - @test OMEinsum.match_rule_binary([3], [3, 3], [3]) isa OMEinsum.DefaultRule - @test OMEinsum.match_rule_binary([3, 3], [3], [3]) isa OMEinsum.DefaultRule - @test OMEinsum.match_rule_binary([3,3], [3, 3], [3,3]) isa OMEinsum.DefaultRule - @test OMEinsum.match_rule_binary([3, 3], [3,2], [2,3]) isa OMEinsum.DefaultRule - @test OMEinsum.match_rule_binary([3,1], [1, 3], [3,3]) isa OMEinsum.DefaultRule - @test OMEinsum.match_rule_binary([1,3], [3, 3], [1,3]) isa OMEinsum.DefaultRule - @test OMEinsum.match_rule_binary([3,3], [3], [3,3]) isa OMEinsum.DefaultRule size_dict = Dict('a'=>2,'b'=>2,'c'=>2) for code in [ein"c,c->cc", ein"c,cc->c", ein"cc,c->cc", ein"cc,cc->cc", ein"cc,cb->bc", ein"cb,bc->cc", ein"ac,cc->ac"] @info code @@ -282,4 +282,26 @@ end b = randn(fill(2, length(getixsv(code)[2]))...) @test code(a, b) ≈ OMEinsum.loop_einsum(code, (a,b), size_dict) end -end \ No newline at end of file +end + +# patch for SymEngine +Base.promote_rule(::Type{Bool}, ::Type{Basic}) = Basic +@testset "allow loops" begin + t = rand(5,5,5,5) + a = rand(5,5) + size_dict = Dict(zip((1,2,3,4,2,3), ((size(t)..., size(a)...)))) + + OMEinsum.allow_loops(false) + @test_throws ErrorException loop_einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) + OMEinsum.allow_loops(true) + + ta = loop_einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) + @test einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) ≈ ta + + # index-sum + t = Basic.(rand(5,5,5,5)) + a = Basic.(rand(5,5)) + size_dict = Dict(zip((1,2,3,4,2,3), ((size(t)..., size(a)...)))) + ta = loop_einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) + @test einsum(EinCode(((1,2,3,4), (2,3)), (1,4)), (t,a), size_dict) ≈ ta +end diff --git a/test/EinRule.jl b/test/matchrule.jl similarity index 78% rename from test/EinRule.jl rename to test/matchrule.jl index d788317..0382579 100644 --- a/test/EinRule.jl +++ b/test/matchrule.jl @@ -1,7 +1,18 @@ -using Test, OMEinsum -using OMEinsum: match_rule, Sum, Tr, DefaultRule, Diag, Duplicate, - Permutedims, nopermute, - Identity, SimpleBinaryRule +using OMEinsum, Test +using OMEinsum: match_rule, DefaultRule, Diag, Sum, Tr, Permutedims, Duplicate, SimpleBinaryRule, Identity, nopermute + +@testset "fix rule cc,cb->bc" begin + @test OMEinsum.match_rule_binary([3], [1], [1,3]) isa OMEinsum.SimpleBinaryRule + @test OMEinsum.match_rule_binary([1,3], [2,3], [1,2,3]) isa OMEinsum.SimpleBinaryRule + @test OMEinsum.match_rule_binary([3], [3], [3,3]) isa OMEinsum.DefaultRule + @test OMEinsum.match_rule_binary([3], [3, 3], [3]) isa OMEinsum.DefaultRule + @test OMEinsum.match_rule_binary([3, 3], [3], [3]) isa OMEinsum.DefaultRule + @test OMEinsum.match_rule_binary([3,3], [3, 3], [3,3]) isa OMEinsum.DefaultRule + @test OMEinsum.match_rule_binary([3, 3], [3,2], [2,3]) isa OMEinsum.DefaultRule + @test OMEinsum.match_rule_binary([3,1], [1, 3], [3,3]) isa OMEinsum.DefaultRule + @test OMEinsum.match_rule_binary([1,3], [3, 3], [1,3]) isa OMEinsum.DefaultRule + @test OMEinsum.match_rule_binary([3,3], [3], [3,3]) isa OMEinsum.DefaultRule +end @testset "match rule" begin ixs = ((1,2), (2,3)) @@ -103,4 +114,4 @@ end @testset "match_rule eye candies" 
begin @test match_rule(ein"ij,jk,kl->il") == DefaultRule() -end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9c6cec5..10a6aea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,21 +5,56 @@ using CUDA import Documenter -@testset "OMEinsum.jl" begin +@testset "Core" begin include("Core.jl") - include("EinRule.jl") +end + +@testset "match rule" begin + include("matchrule.jl") +end + +@testset "unary rules" begin + include("unaryrules.jl") +end + +@testset "binary rules" begin include("binaryrules.jl") +end + +@testset "utils" begin include("utils.jl") +end + +@testset "einsum" begin include("einsum.jl") +end + +@testset "cuda" begin if CUDA.functional() include("cueinsum.jl") end +end + +@testset "autodiff" begin include("autodiff.jl") +end + +@testset "einsequence" begin include("einsequence.jl") +end + +@testset "slicing" begin include("slicing.jl") +end + +@testset "interfaces" begin include("interfaces.jl") +end +@testset "contraction order" begin include("contractionorder.jl") +end +@testset "docstring" begin Documenter.doctest(OMEinsum; manual=false) end diff --git a/test/unaryrules.jl b/test/unaryrules.jl new file mode 100644 index 0000000..a1d8bc3 --- /dev/null +++ b/test/unaryrules.jl @@ -0,0 +1,56 @@ +using OMEinsum, Test +using OMEinsum: unary_einsum!, Duplicate, Sum, Tr, Permutedims, Repeat, Diag, Identity +using SymEngine: Basic + +@testset "Duplicate" begin + ix = (1,2,3) + iy = (3,2,1,1,2) + size_dict = Dict(1=>3,2=>3,3=>3) + x = randn(3,3,3) + y = randn(3,3,3,3,3) + @test OMEinsum.duplicate!(y, x, ix, iy, true, false) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) + @test unary_einsum!(Duplicate(), ix, iy, x, y, true, false) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) +end + +@testset "Diag" begin + ix = (3,2,1,1,2) + iy = (1,2,3) + size_dict = Dict(1=>3,2=>3,3=>3) + x = randn(3,3,3,3,3) + y = randn(3,3,3) + @test unary_einsum!(Diag(), ix, iy, x, y, true, false) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) +end + +@testset "Repeat" begin + ix = (1,2,3) + iy = (3,4,2,1) + size_dict = Dict(1=>3,2=>3,3=>3,4=>5) + x = randn(3,3,3) + y = randn(3,5,3,3) + @test unary_einsum!(Repeat(), ix, iy, x, y, true, false) ≈ OMEinsum.loop_einsum(EinCode((ix,),iy), (x,), size_dict) +end + +@testset "Tr" begin + a = rand(5,5) + @test unary_einsum!(Tr(), (1,1),(), a, fill(1.0), true, false)[] ≈ sum(a[i,i] for i in 1:5) + a = Basic.(rand(5,5)) + @test isapprox(unary_einsum!(Tr(), (1,1),(), a, fill(Basic(0)), 1, 0)[], sum(a[i,i] for i in 1:5), rtol=1e-8) +end + +@testset "Permutedims" begin + a = rand(5,5,3) + @test unary_einsum!(Permutedims(), (1,2,3), (2,3,1), a, zeros(5, 3, 5), true, false) ≈ permutedims(a, (2,3,1)) +end + +@testset "Identity" begin + a = rand(5,5,3) + @test unary_einsum!(Identity(), (1,2,3), (1,2,3), a, ones(5, 5, 3), 2.0, 3.0) ≈ 3 .+ 2a +end + +@testset "Sum" begin + # index-sum + a = rand(2,2,5) + @test unary_einsum!(Sum(), (1, 2, 3), (1,2), a, zeros(2, 2), true, false) ≈ sum(a, dims=3) + a = Basic.(rand(1:100, 2,2,5)) + @test unary_einsum!(Sum(), (1, 2, 3) ,(1,2), a, zeros(Basic, 2, 2), 1, 0) == dropdims(sum(a; dims=3); dims=3) +end \ No newline at end of file diff --git a/test/utils.jl b/test/utils.jl index 49c10bb..ab6ba73 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,6 @@ using OMEinsum: _unique +using OMEinsum +using Test @testset "utils" begin @test _unique(Int,(1,2,3,3,)) == [1,2,3] @@ -18,8 +20,8 @@ end @testset "tensorpermute" begin a = randn(100, 100) - @test 
OMEinsum.tensorpermute(a, [1,2]) == a - @test OMEinsum.tensorpermute(a, (2,1)) == transpose(a) + @test OMEinsum.tensorpermute!(zero(a), a, [1,2], true, false) == a + @test OMEinsum.tensorpermute!(zero(a), a, (2,1), true, false) == transpose(a) end @testset "align_types" begin @@ -37,7 +39,17 @@ end for C2 in ['N', 'T'] A_ = Array{Any}(A) B_ = Array{Any}(B) - @test OMEinsum._batched_gemm(C1, C2, A, B) ≈ OMEinsum._batched_gemm(C1, C2, A_, B_) + @test OMEinsum._batched_gemm!(C1, C2, true, A, B, false, zeros(10, 10, 10)) ≈ OMEinsum._batched_gemm!(C1, C2, true, A_, B_, false, zeros(10, 10, 10)) end end end + +@testset "addmul!" begin + x = randn(10, 10) + y = randn(10, 10) + z = randn(10, 10) + for a in [0.0, 1.0, 4.0], b in [0.0, 1.0, 4.0] + @test (o = copy(x); OMEinsum.@addmul! a * o + b * y * z) ≈ a .* x .+ b .* y .* z + @test (o = copy(x); OMEinsum.@flatten_addmul! a * o + b * y * z) ≈ a .* x .+ b .* y .* z + end +end \ No newline at end of file
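
A minimal usage sketch of the in-place contraction convention exercised by the new tests. It assumes `einsum!` is callable exactly as in `test/einsum.jl` (label tuples, operand tuple, output buffer, two scale factors, and a size dictionary); the labels, dimensions, and scale values below are arbitrary placeholders, not values from the patch.

```julia
using OMEinsum

# einsum!(ixs, iy, xs, y, sx, sy, size_dict) overwrites y with
#   sx * einsum(ixs, iy, xs) + sy * y,
# mirroring the five-argument mul! convention.
size_dict = Dict(1 => 2, 2 => 2, 3 => 2)
a, b = randn(2, 2), randn(2, 2)
c = zeros(2, 2)

einsum!(((1, 2), (2, 3)), (1, 3), (a, b), c, true, false, size_dict)
@assert c ≈ a * b
einsum!(((1, 2), (2, 3)), (1, 3), (a, b), c, 2.0, 1.0, size_dict)
@assert c ≈ 3 * (a * b)

# The unary path follows the same convention, e.g. a plain index sum:
x = randn(2, 2, 2)
y = zeros(2, 2)
einsum!(((1, 2, 3),), (1, 2), (x,), y, true, false, size_dict)
@assert y ≈ dropdims(sum(x; dims=3); dims=3)
```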
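
Similarly, a hedged sketch of the scaled batched GEMM and the fused update macro touched by this patch; both are internal (unexported) helpers, and the array sizes and coefficients `a`/`b` are placeholders chosen to mirror `test/utils.jl`.

```julia
using OMEinsum
using LinearAlgebra

# _batched_gemm!(C1, C2, alpha, A, B, beta, C) updates each batch slice as
#   C[:, :, l] <- alpha * op(A[:, :, l]) * op(B[:, :, l]) + beta * C[:, :, l],
# where op is the identity for 'N' and transpose for 'T'.
A, B, C = randn(10, 10, 10), randn(10, 10, 10), zeros(10, 10, 10)
OMEinsum._batched_gemm!('N', 'T', true, A, B, false, C)
@assert C[:, :, 1] ≈ A[:, :, 1] * transpose(B[:, :, 1])

# @addmul! performs the fused scale-and-add update checked in test/utils.jl:
# starting from o == x, the result is a .* x .+ b .* y .* z.
a, b = 2.0, 3.0
x, y, z = randn(10, 10), randn(10, 10), randn(10, 10)
o = copy(x)
res = OMEinsum.@addmul! a * o + b * y * z
@assert res ≈ a .* x .+ b .* y .* z
```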