From f7ad72af079b62e701b0742aa237de0931d34988 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sat, 13 Jan 2024 16:23:37 +0000 Subject: [PATCH] fix bugs --- README.md | 14 ++++++++ ext/CUDAExt.jl | 54 ++++++++++++++++------------ src/binaryrules.jl | 12 +++---- src/einsequence.jl | 9 +++++ test/cueinsum.jl | 85 +++++++++++++++++++++++++++++++++++---------- test/einsequence.jl | 2 ++ 6 files changed, 129 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index c0f748c..af076a4 100644 --- a/README.md +++ b/README.md @@ -224,6 +224,20 @@ julia> gradient(x->optcode(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum ``` This tells us that even if we allow duplicates on one vertex, there are no 3-colourings for the peterson graph. +## Comparison with other packages +Similar packages include: +- [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl) and [TensorKit.jl](https://github.com/Jutho/TensorKit.jl) +- [ITensors.jl](https://github.com/ITensor/ITensors.jl) + +Comparing with the above packages, `OMEinsum` is optimized over large scale tensor network (or einsum, sum-product network) contraction. Its main advantages are: +- `OMEinsum` has better support to very high dimensional tensor networks and their contraction order. +- `OMEinsum` allows an index to appear multiple times. +- `OMEinsum` has well tested generic element type support. + +However, `OMEinsum` also has some disadvantages: +- `OMEinsum` does not support good quantum numbers. +- `OMEinsum` has less optimization on small scale problems. + ## Contribute Suggestions and Comments in the _Issues_ are welcome. diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 7aa00bf..2734eeb 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -1,6 +1,6 @@ module CUDAExt -import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm!, asscalar, @addmul! +import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm!, asscalar, @flatten_addmul! using OMEinsum: EinArray, Diag, Repeat, Duplicate, DefaultRule, EinCode, DynamicEinCode, StaticEinCode, NestedEinsum, SimpleBinaryRule, match_rule, loop_einsum, getiy, getixs, _unique, einarray, align_eltypes, siblings, isleaf, tensorindex, _safe_set, rootcode import OMEinsum using LinearAlgebra @@ -29,18 +29,11 @@ end CUDA.cudaconvert(A::EinArray{T}) where T = EinArray{T}(cudaconvert.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS) CUDA.cu(A::EinArray{T}) where T = EinArray{T}(cu.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS) -for TP in [:Diag, :Repeat, :Duplicate, :DefaultRule] - @eval function OMEinsum.unary_einsum!(::$TP, ixs, iy, xs::Tuple{<:CUDAArrayTypes}, y::CUDAArrayTypes, sx, sy, size_dict::Dict{LT}) where LT - @debug "cueinsum fallback to loop_einsum" rule ixs => iy size.(xs) - loop_einsum!(ixs, iy, xs, y, sx, sy, size_dict) - end -end - -function OMEinsum.binary_einsum!(::SimpleBinaryRule{('j',), ('j',), ()}, xs::NTuple{2, CUDAArrayTypes}, y::CUDAArrayTypes, sx, sy, size_dict::Dict) - if iszero(sy) - CUDA.@allowscalar y[] = sx * reshape(xs[1],1,:) * xs[2] - else - CUDA.@allowscalar y[] = sy * y[] + sx * reshape(xs[1],1,:) * xs[2] +for TP in [:Diag, :Repeat, :Duplicate] + @eval function OMEinsum.unary_einsum!(::$TP, ix, iy, x::CUDAArrayTypes, y::CUDAArrayTypes, sx, sy) + @debug "cueinsum fallback to loop_einsum" rule ix => iy size(x) + size_dict = OMEinsum.get_size_dict((ix, iy), (x, y)) + loop_einsum!((ix,), iy, (x,), y, sx, sy, size_dict) end end @@ -51,24 +44,25 @@ function loop_einsum!(ixs0, iy0, ixs = (Tuple.(ixs0)...,) iy_ = _unique(LT,iy) NO = length(iy_) - if iszero(sy) - fill!(y, zero(T)) - else - lmul!(sy, y) - end - A = einarray(Val(ixs), Val(iy), xs, size_dict) + A = einarray(Val(ixs), Val((iy_...,)), xs, size_dict) if NO == length(iy) raw_ = similar(y, (fill(1, ndims(A)-NO)...,size(y)...,)) + fill!(raw_, zero(T)) Base.mapreducedim!(x->x, +, raw_, A) if ndims(A)-NO > 0 # fix 1.7 compatibility raw = dropdims(raw_, dims=(1:ndims(A)-NO...,)) else raw = raw_ end - return @addmul! sy * y + sx * raw + return @flatten_addmul! sy * y + sx * raw else - y_ = CUDA.zeros(T, size(A)[end-NO+1:end]...) - y_ = reshape(y_, fill(1, ndims(A)-NO)...,size(y_)...) + if iszero(sy) + fill!(y, zero(T)) + else + lmul!(sy, y) + end + y_ = similar(y, (fill(1, ndims(A)-NO)...,[size_dict[l] for l in iy_]...)) + fill!(y_, zero(T)) raw = Base.mapreducedim!(x->x, +, y_, A) if ndims(A)-NO > 0 # fix 1.7 compatibility raw = dropdims(raw, dims=(1:ndims(A)-NO...,)) @@ -91,7 +85,7 @@ function expanddims!(::Val{ixs}, ::Val{iy}, x, y, sx) where {ixs,iy} i = (blockIdx().x-1) * blockDim().x + threadIdx().x i > length(x) && return nothing @inbounds yi = expandind(Val{ixs}(), Val{iy}(), CIS[i].I) - @inbounds y[CartesianIndex(yi)] = sx * x[i] + @inbounds y[CartesianIndex(yi)] += sx * x[i] nothing end @cuda(blocks=nblocks, threads=nthreads, kernel(y, x)) @@ -118,6 +112,20 @@ function einsum!(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayType return res end +function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict; active_free=false) + # do not use map because the static overhead is too large + # do not use `setindex!` because we need to make the AD work + mxs = Vector{AbstractArray}(undef, length(siblings(neinsum))) + for (i, arg) in enumerate(siblings(neinsum)) + mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum(arg, xs, size_dict; active_free=active_free)) + end + res = einsum(rootcode(neinsum), (mxs...,), size_dict) + active_free && for mx in mxs # free CuArray aggressively. + CUDA.unsafe_free!(mx) + end + return res +end + @info("OMEinsum loaded the CUDA module successfully") end \ No newline at end of file diff --git a/src/binaryrules.jl b/src/binaryrules.jl index eeb921a..e4f5756 100644 --- a/src/binaryrules.jl +++ b/src/binaryrules.jl @@ -30,14 +30,14 @@ end # S = N # T = N function binary_einsum!(::SimpleBinaryRule{('j',), ('j',), ()}, x1, x2, y, sx, sy) - @addmul! sy * y + sx * Ref(transpose(x1) * x2) + @addmul! sy * y + sx * (reshape(x1, 1, length(x1)) * x2) end # ,k->k : 001 # S = N # T = N @inline function binary_einsum!(::SimpleBinaryRule{(), ('k',), ('k',)}, x1, x2, y, sx, sy) - binary_einsum!(SimpleBinaryRule{('i',),(),('i',)}(), x2, x1, y, sx, sy) + @addmul! sy * y + sx * Ref(asscalar(x1)) * x2 end # j,jk->k : 011 @@ -64,10 +64,10 @@ end # S = N^2 # T = N^2 function binary_einsum!(::SimpleBinaryRule{('i',), ('k',), ('i','k')}, x1, x2, y, sx, sy) - @addmul! sy * y + sx * x1 * transpose(x2) + @addmul! sy * y + sx * x1 * reshape(x2, 1, length(x2)) end @inline function binary_einsum!(::SimpleBinaryRule{('i',), ('k',),('k','i')}, x1, x2, y, sx, sy) - @addmul! sy * y + sx * transpose(x1) * x2 + @addmul! sy * y + sx * reshape(x1, 1, length(x1)) * x2 end # 000 @@ -77,12 +77,12 @@ end # 100 function binary_einsum!(::SimpleBinaryRule{('i','l'),('l',), ('i','l')}, x1, x2, y, sx, sy) - @addmul! sy * y + sx * x1 * transpose(x2) + @addmul! sy * y + sx * x1 * reshape(x2, 1, length(x2)) end # 001 @inline function binary_einsum!(::SimpleBinaryRule{('l',), ('k','l'), ('k','l')}, x1, x2, y, sx, sy) - binary_einsum!(SimpleBinaryRule{('i','l'),('l',),('i','l')}(), x2, x1, y, sx, sy) + @addmul! sy * y + sx * reshape(x1, 1, length(x1)) * x2 end # 010 diff --git a/src/einsequence.jl b/src/einsequence.jl index aeb875e..ad87c64 100644 --- a/src/einsequence.jl +++ b/src/einsequence.jl @@ -276,6 +276,15 @@ function einsum!(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,AbstractArray end return einsum!(rootcode(neinsum), (mxs...,), y, sx, sy, size_dict) end +function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), size_dict::Dict) + # do not use map because the static overhead is too large + # do not use `setindex!` because we need to make the AD work + mxs = Vector{AbstractArray}(undef, length(siblings(neinsum))) + for (i, arg) in enumerate(siblings(neinsum)) + mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum(arg, xs, size_dict)) + end + return einsum(rootcode(neinsum), (mxs...,), size_dict) +end _safe_set(lst, i, y) = (lst[i] = y; lst) diff --git a/test/cueinsum.jl b/test/cueinsum.jl index dac134f..9b6d917 100644 --- a/test/cueinsum.jl +++ b/test/cueinsum.jl @@ -6,12 +6,27 @@ using Zygote CUDA.allowscalar(false) +@testset "loop einsum" begin + a = [randn(fill(3, i)...) for i=1:4] + ca = a .|> CuArray + ixs = ((1,2), (2,3)) + xs = (ca[2], ca[2]) + @test loop_einsum!(ixs, (1,3), xs, zeros(3,3) |> CuArray, true, false, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ a[2]*a[2] + @test loop_einsum!(ixs, (1,3), xs, ones(3,3) |> CuArray, 4.0, 2.0, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ 4.0 * a[2]*a[2] + 2 * ones(3, 3) + res = 4.0 * a[2]*a[2] + out = 2 * ones(3, 3, 3) + for k = 1:3 + out[:,k,k] .+= res[:,k] + end + @test loop_einsum!(ixs, (1,3,3), xs, ones(3,3,3) |> CuArray, 4.0, 2.0, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ out +end + @testset "cuda einsum" begin a = [randn(fill(3, i)...) for i=1:4] ca = a .|> CuArray ixs = ((1,2), (2,3)) xs = (ca[2], ca[2]) - @test loop_einsum!(EinCode(ixs, (1,3)), xs, zeros(3,3) |> CuArray, OMEinsum.get_size_dict(ixs, xs)) ≈ ca[2]*ca[2] + @test loop_einsum!(ixs, (1,3), xs, zeros(3,3) |> CuArray, true, false, OMEinsum.get_size_dict(ixs, xs)) ≈ ca[2]*ca[2] for f in [ein"ij,jk->ik", ein"ii->", ein"ijj ->i", ein"ij,ik,il->jkl", ein"ii->i", ein"ijl->i", ein"i->ii", ein"ij,jk,kl->il", ein"ij,ij,ij -> ij"] cins = map(ix->ca[length(ix)], OMEinsum.getixs(f)) ins = map(ix->a[length(ix)], OMEinsum.getixs(f)) @@ -56,7 +71,7 @@ end @test M |> Array ≈ loop_einsum(_code, xs, OMEinsum.get_size_dict(OMEinsum.getixs(_code), xs)) |> Array end -@testset "binary einsum" begin +@testset "unary einsum rules" begin for code in [ ein"ij->", # sum ein"ij->j", # sum @@ -69,10 +84,58 @@ end ein"i->ik", # ~sum ein"->ik", # ~sum ein"illljkk->kijjcc", # general + ] + @info code + D = 2 + xs = [length(ix)==0 ? CUDA.fill(1.2) : CUDA.rand(Float64, fill(D, length(ix))...) for ix in OMEinsum.getixs(code)] + size_dict = Dict(zip(('a', 'b', 'c', 'd', 'e', 'f','i','j','k','l'), ntuple(x->D, 10))) + res = einsum(code, (xs...,), size_dict) + @test Array(res) ≈ loop_einsum(code, (Array.(xs)...,), size_dict) + @test Array(res) ≈ Array(loop_einsum(code, (xs...,), size_dict)) + end +end + +@testset "binary einsum rules" begin + codes = [ # binary ein",->", - ein"ijl,jl->il", - ein"ab,bc->ac", + ein"i,->i", + ein"j,j->", + ein",k->k", + ein"j,jk->k", + ein"j,kj->k", + ein"ij,j->i", + ein"ji,j->i", + ein"i,k->ik", + ein"i,k->ki", + ] + + for (i1, X1) in enumerate([('i', 'j'), ('j', 'i')]) + for (i2, X2) in enumerate([('j', 'k'), ('k', 'j')]) + for (i3, X3) in enumerate([('i', 'k'), ('k', 'i')]) + push!(codes, OMEinsum.StaticEinCode{Char, (X1,X2),X3}()) + end + end + end + for code in copy(codes) + X1, X2 = OMEinsum.getixs(code) + X3 = OMEinsum.getiy(code) + push!(codes, OMEinsum.StaticEinCode{Char, ((X1...,'l'),(X2...,'l')),(X3...,'l')}()) + end + + for code in codes + @info code + D = 2 + xs = [length(ix)==0 ? CUDA.fill(1.2) : CUDA.rand(Float64, fill(D, length(ix))...) for ix in OMEinsum.getixs(code)] + size_dict = Dict(zip(('a', 'b', 'c', 'd', 'e', 'f','i','j','k','l'), ntuple(x->D, 10))) + res = einsum(code, (xs...,), size_dict) + @test Array(res) ≈ loop_einsum(code, (Array.(xs)...,), size_dict) + @test Array(res) ≈ Array(loop_einsum(code, (xs...,), size_dict)) + end +end + +@testset "composite einsum" begin + for code in [ ein"abb,bc->ac", # with diag in ein"ab,bc->acc", # with diag out ein"ab,bce->ac", # with sum in @@ -92,20 +155,6 @@ end end end -@testset "binary rules" begin - for (code, a, b) in [ - (ein"j,j->", randn(10), randn(10)), - (ein"i,->i", randn(10), fill(2.0, ())), - (ein",->", fill(2.0,()), fill(2.0, ())), - (ein"il,kl->ikl", randn(10, 10), randn(10, 10)), - ] - res0 = code(a, b) - res1 = code(CuArray(a), CuArray(b)) - @test res1 isa CuArray - @test res0 ≈ Array(res1) - end -end - @testset "permutedims for high dimensional tensors" begin c = CUDA.rand(4, [rand(1:3) for _ = 2:18]...); @test Array(permutedims(c, 18:-1:1)) ≈ permutedims(Array(c), 18:-1:1) diff --git a/test/einsequence.jl b/test/einsequence.jl index a749aae..20f43ec 100644 --- a/test/einsequence.jl +++ b/test/einsequence.jl @@ -34,6 +34,8 @@ using OMEinsum: IndexGroup, NestedEinsum, parse_nested, DynamicEinCode, isleaf, size_info = Dict('k'=>2) a, b, c, d = randn(2), randn(2,2), randn(2), randn(2) @test ein"((i,ij),i),j->ik"(a, b, c, d; size_info=size_info) ≈ ein"i,ij,i,j->ik"(a, b, c, d; size_info=size_info) + size_dict = Dict([l=>2 for l in "ijkl"]) + @test einsum!(ein"((i,ij),i),j->ik", (a, b, c, d), randn(2, 2), true, false, size_dict) ≈ ein"i,ij,i,j->ik"(a, b, c, d; size_info=size_info) @test getixsv(ein"((i,ij),i),j->ik") == getixsv(ein"i,ij,i,j->ik") == getixsv(DynamicEinCode(ein"i,ij,i,j->ik")) == [['i'], ['i','j'], ['i'], ['j']] @test getiyv(ein"((i,ij),i),j->ik") == getiyv(ein"i,ij,i,j->ik") == getiyv(DynamicEinCode(ein"i,ij,i,j->ik")) == ['i','k'] end