Commit f7ad72a
fix bugs
GiggleLiu committed Jan 13, 2024
1 parent 5c8d786 commit f7ad72a
Showing 6 changed files with 129 additions and 47 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -224,6 +224,20 @@ julia> gradient(x->optcode(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum
```
This tells us that even if we allow duplicate colours on one vertex, there are no 3-colourings for the Petersen graph.

## Comparison with other packages
Similar packages include:
- [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl) and [TensorKit.jl](https://github.com/Jutho/TensorKit.jl)
- [ITensors.jl](https://github.com/ITensor/ITensors.jl)

Compared with the packages above, `OMEinsum` is optimized for large-scale tensor network (einsum, sum-product network) contraction. Its main advantages are:
- `OMEinsum` has better support for very high dimensional tensor networks and the optimization of their contraction order.
- `OMEinsum` allows an index to appear multiple times (see the sketch below).
- `OMEinsum` has well-tested support for generic element types.

However, `OMEinsum` also has some disadvantages:
- `OMEinsum` does not support good quantum numbers (i.e. symmetry-blocked tensors).
- `OMEinsum` is less optimized for small-scale problems.
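
Here is a minimal sketch of the first two advantages (the contractions are illustrative toys; it assumes every index has dimension 2 for the order optimization and uses the `TreeSA` optimizer that `OMEinsum` re-exports):

```julia
using OMEinsum

# An index may repeat: `j` appears twice on the input, so
# y[i] = Σ_j x[i,j,j], a partial trace over the last two dimensions.
x = randn(3, 3, 3)
y = ein"ijj->i"(x)

# Contraction order optimization for a longer chain of matrices.
code = ein"ij,jk,kl,lm->im"
sizes = uniformsize(code, 2)                    # assumption: all indices have dimension 2
optcode = optimize_code(code, sizes, TreeSA())  # returns an order-optimized nested code
```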

## Contribute

Suggestions and Comments in the _Issues_ are welcome.
54 changes: 31 additions & 23 deletions ext/CUDAExt.jl
@@ -1,6 +1,6 @@
module CUDAExt

import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm!, asscalar, @addmul!
import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm!, asscalar, @flatten_addmul!
using OMEinsum: EinArray, Diag, Repeat, Duplicate, DefaultRule, EinCode, DynamicEinCode, StaticEinCode, NestedEinsum, SimpleBinaryRule, match_rule, loop_einsum, getiy, getixs, _unique, einarray, align_eltypes, siblings, isleaf, tensorindex, _safe_set, rootcode
import OMEinsum
using LinearAlgebra
@@ -29,18 +29,11 @@ end
CUDA.cudaconvert(A::EinArray{T}) where T = EinArray{T}(cudaconvert.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS)
CUDA.cu(A::EinArray{T}) where T = EinArray{T}(cu.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS)

for TP in [:Diag, :Repeat, :Duplicate, :DefaultRule]
@eval function OMEinsum.unary_einsum!(::$TP, ixs, iy, xs::Tuple{<:CUDAArrayTypes}, y::CUDAArrayTypes, sx, sy, size_dict::Dict{LT}) where LT
@debug "cueinsum fallback to loop_einsum" rule ixs => iy size.(xs)
loop_einsum!(ixs, iy, xs, y, sx, sy, size_dict)
end
end

function OMEinsum.binary_einsum!(::SimpleBinaryRule{('j',), ('j',), ()}, xs::NTuple{2, CUDAArrayTypes}, y::CUDAArrayTypes, sx, sy, size_dict::Dict)
if iszero(sy)
CUDA.@allowscalar y[] = sx * reshape(xs[1],1,:) * xs[2]
else
CUDA.@allowscalar y[] = sy * y[] + sx * reshape(xs[1],1,:) * xs[2]
for TP in [:Diag, :Repeat, :Duplicate]
@eval function OMEinsum.unary_einsum!(::$TP, ix, iy, x::CUDAArrayTypes, y::CUDAArrayTypes, sx, sy)
@debug "cueinsum fallback to loop_einsum" rule ix => iy size(x)
size_dict = OMEinsum.get_size_dict((ix, iy), (x, y))
loop_einsum!((ix,), iy, (x,), y, sx, sy, size_dict)
end
end

@@ -51,24 +44,25 @@ function loop_einsum!(ixs0, iy0,
ixs = (Tuple.(ixs0)...,)
iy_ = _unique(LT,iy)
NO = length(iy_)
if iszero(sy)
fill!(y, zero(T))
else
lmul!(sy, y)
end
A = einarray(Val(ixs), Val(iy), xs, size_dict)
A = einarray(Val(ixs), Val((iy_...,)), xs, size_dict)
if NO == length(iy)
raw_ = similar(y, (fill(1, ndims(A)-NO)...,size(y)...,))
fill!(raw_, zero(T))
Base.mapreducedim!(x->x, +, raw_, A)
if ndims(A)-NO > 0 # fix 1.7 compatibility
raw = dropdims(raw_, dims=(1:ndims(A)-NO...,))
else
raw = raw_
end
return @addmul! sy * y + sx * raw
return @flatten_addmul! sy * y + sx * raw
else
y_ = CUDA.zeros(T, size(A)[end-NO+1:end]...)
y_ = reshape(y_, fill(1, ndims(A)-NO)...,size(y_)...)
if iszero(sy)
fill!(y, zero(T))
else
lmul!(sy, y)
end
y_ = similar(y, (fill(1, ndims(A)-NO)...,[size_dict[l] for l in iy_]...))
fill!(y_, zero(T))
raw = Base.mapreducedim!(x->x, +, y_, A)
if ndims(A)-NO > 0 # fix 1.7 compatibility
raw = dropdims(raw, dims=(1:ndims(A)-NO...,))
@@ -91,7 +85,7 @@ function expanddims!(::Val{ixs}, ::Val{iy}, x, y, sx) where {ixs,iy}
i = (blockIdx().x-1) * blockDim().x + threadIdx().x
i > length(x) && return nothing
@inbounds yi = expandind(Val{ixs}(), Val{iy}(), CIS[i].I)
@inbounds y[CartesianIndex(yi)] = sx * x[i]
@inbounds y[CartesianIndex(yi)] += sx * x[i]
nothing
end
@cuda(blocks=nblocks, threads=nthreads, kernel(y, x))
@@ -118,6 +112,20 @@ function einsum!(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayType
return res
end

function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict; active_free=false)
# do not use map because the static overhead is too large
# do not use `setindex!` because we need to make the AD work
mxs = Vector{AbstractArray}(undef, length(siblings(neinsum)))
for (i, arg) in enumerate(siblings(neinsum))
mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum(arg, xs, size_dict; active_free=active_free))
end
res = einsum(rootcode(neinsum), (mxs...,), size_dict)
active_free && for mx in mxs # free CuArray aggressively.
CUDA.unsafe_free!(mx)
end
return res
end

@info("OMEinsum loaded the CUDA module successfully")

end
12 changes: 6 additions & 6 deletions src/binaryrules.jl
@@ -30,14 +30,14 @@ end
# S = N
# T = N
function binary_einsum!(::SimpleBinaryRule{('j',), ('j',), ()}, x1, x2, y, sx, sy)
@addmul! sy * y + sx * Ref(transpose(x1) * x2)
@addmul! sy * y + sx * (reshape(x1, 1, length(x1)) * x2)
end

# ,k->k : 001
# S = N
# T = N
@inline function binary_einsum!(::SimpleBinaryRule{(), ('k',), ('k',)}, x1, x2, y, sx, sy)
binary_einsum!(SimpleBinaryRule{('i',),(),('i',)}(), x2, x1, y, sx, sy)
@addmul! sy * y + sx * Ref(asscalar(x1)) * x2
end

# j,jk->k : 011
@@ -64,10 +64,10 @@ end
# S = N^2
# T = N^2
function binary_einsum!(::SimpleBinaryRule{('i',), ('k',), ('i','k')}, x1, x2, y, sx, sy)
@addmul! sy * y + sx * x1 * transpose(x2)
@addmul! sy * y + sx * x1 * reshape(x2, 1, length(x2))
end
@inline function binary_einsum!(::SimpleBinaryRule{('i',), ('k',),('k','i')}, x1, x2, y, sx, sy)
@addmul! sy * y + sx * transpose(x1) * x2
@addmul! sy * y + sx * reshape(x1, 1, length(x1)) * x2
end

# 000
@@ -77,12 +77,12 @@ end

# 100
function binary_einsum!(::SimpleBinaryRule{('i','l'),('l',), ('i','l')}, x1, x2, y, sx, sy)
@addmul! sy * y + sx * x1 * transpose(x2)
@addmul! sy * y + sx * x1 * reshape(x2, 1, length(x2))
end

# 001
@inline function binary_einsum!(::SimpleBinaryRule{('l',), ('k','l'), ('k','l')}, x1, x2, y, sx, sy)
binary_einsum!(SimpleBinaryRule{('i','l'),('l',),('i','l')}(), x2, x1, y, sx, sy)
@addmul! sy * y + sx * reshape(x1, 1, length(x1)) * x2
end

# 010
9 changes: 9 additions & 0 deletions src/einsequence.jl
@@ -276,6 +276,15 @@ function einsum!(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,AbstractArray
end
return einsum!(rootcode(neinsum), (mxs...,), y, sx, sy, size_dict)
end
function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), size_dict::Dict)
# do not use map because the static overhead is too large
# do not use `setindex!` because we need to make the AD work
mxs = Vector{AbstractArray}(undef, length(siblings(neinsum)))
for (i, arg) in enumerate(siblings(neinsum))
mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum(arg, xs, size_dict))
end
return einsum(rootcode(neinsum), (mxs...,), size_dict)
end

_safe_set(lst, i, y) = (lst[i] = y; lst)

85 changes: 67 additions & 18 deletions test/cueinsum.jl
@@ -6,12 +6,27 @@ using Zygote

CUDA.allowscalar(false)

@testset "loop einsum" begin
a = [randn(fill(3, i)...) for i=1:4]
ca = a .|> CuArray
ixs = ((1,2), (2,3))
xs = (ca[2], ca[2])
@test loop_einsum!(ixs, (1,3), xs, zeros(3,3) |> CuArray, true, false, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ a[2]*a[2]
@test loop_einsum!(ixs, (1,3), xs, ones(3,3) |> CuArray, 4.0, 2.0, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ 4.0 * a[2]*a[2] + 2 * ones(3, 3)
res = 4.0 * a[2]*a[2]
out = 2 * ones(3, 3, 3)
for k = 1:3
out[:,k,k] .+= res[:,k]
end
@test loop_einsum!(ixs, (1,3,3), xs, ones(3,3,3) |> CuArray, 4.0, 2.0, OMEinsum.get_size_dict(ixs, xs)) |> Array ≈ out
end

@testset "cuda einsum" begin
a = [randn(fill(3, i)...) for i=1:4]
ca = a .|> CuArray
ixs = ((1,2), (2,3))
xs = (ca[2], ca[2])
@test loop_einsum!(EinCode(ixs, (1,3)), xs, zeros(3,3) |> CuArray, OMEinsum.get_size_dict(ixs, xs)) ≈ ca[2]*ca[2]
@test loop_einsum!(ixs, (1,3), xs, zeros(3,3) |> CuArray, true, false, OMEinsum.get_size_dict(ixs, xs)) ≈ ca[2]*ca[2]
for f in [ein"ij,jk->ik", ein"ii->", ein"ijj ->i", ein"ij,ik,il->jkl", ein"ii->i", ein"ijl->i", ein"i->ii", ein"ij,jk,kl->il", ein"ij,ij,ij -> ij"]
cins = map(ix->ca[length(ix)], OMEinsum.getixs(f))
ins = map(ix->a[length(ix)], OMEinsum.getixs(f))
@@ -56,7 +71,7 @@
@test M |> Array ≈ loop_einsum(_code, xs, OMEinsum.get_size_dict(OMEinsum.getixs(_code), xs)) |> Array
end

@testset "binary einsum" begin
@testset "unary einsum rules" begin
for code in [
ein"ij->", # sum
ein"ij->j", # sum
@@ -69,10 +84,58 @@ end
ein"i->ik", # ~sum
ein"->ik", # ~sum
ein"illljkk->kijjcc", # general
]
@info code
D = 2
xs = [length(ix)==0 ? CUDA.fill(1.2) : CUDA.rand(Float64, fill(D, length(ix))...) for ix in OMEinsum.getixs(code)]
size_dict = Dict(zip(('a', 'b', 'c', 'd', 'e', 'f','i','j','k','l'), ntuple(x->D, 10)))
res = einsum(code, (xs...,), size_dict)
@test Array(res) ≈ loop_einsum(code, (Array.(xs)...,), size_dict)
@test Array(res) ≈ Array(loop_einsum(code, (xs...,), size_dict))
end
end

@testset "binary einsum rules" begin
codes = [
# binary
ein",->",
ein"ijl,jl->il",
ein"ab,bc->ac",
ein"i,->i",
ein"j,j->",
ein",k->k",
ein"j,jk->k",
ein"j,kj->k",
ein"ij,j->i",
ein"ji,j->i",
ein"i,k->ik",
ein"i,k->ki",
]

for (i1, X1) in enumerate([('i', 'j'), ('j', 'i')])
for (i2, X2) in enumerate([('j', 'k'), ('k', 'j')])
for (i3, X3) in enumerate([('i', 'k'), ('k', 'i')])
push!(codes, OMEinsum.StaticEinCode{Char, (X1,X2),X3}())
end
end
end
for code in copy(codes)
X1, X2 = OMEinsum.getixs(code)
X3 = OMEinsum.getiy(code)
push!(codes, OMEinsum.StaticEinCode{Char, ((X1...,'l'),(X2...,'l')),(X3...,'l')}())
end

for code in codes
@info code
D = 2
xs = [length(ix)==0 ? CUDA.fill(1.2) : CUDA.rand(Float64, fill(D, length(ix))...) for ix in OMEinsum.getixs(code)]
size_dict = Dict(zip(('a', 'b', 'c', 'd', 'e', 'f','i','j','k','l'), ntuple(x->D, 10)))
res = einsum(code, (xs...,), size_dict)
@test Array(res) ≈ loop_einsum(code, (Array.(xs)...,), size_dict)
@test Array(res) ≈ Array(loop_einsum(code, (xs...,), size_dict))
end
end

@testset "composite einsum" begin
for code in [
ein"abb,bc->ac", # with diag in
ein"ab,bc->acc", # with diag out
ein"ab,bce->ac", # with sum in
@@ -92,20 +155,6 @@ end
end
end

@testset "binary rules" begin
for (code, a, b) in [
(ein"j,j->", randn(10), randn(10)),
(ein"i,->i", randn(10), fill(2.0, ())),
(ein",->", fill(2.0,()), fill(2.0, ())),
(ein"il,kl->ikl", randn(10, 10), randn(10, 10)),
]
res0 = code(a, b)
res1 = code(CuArray(a), CuArray(b))
@test res1 isa CuArray
@test res0 ≈ Array(res1)
end
end

@testset "permutedims for high dimensional tensors" begin
c = CUDA.rand(4, [rand(1:3) for _ = 2:18]...);
@test Array(permutedims(c, 18:-1:1)) ≈ permutedims(Array(c), 18:-1:1)
2 changes: 2 additions & 0 deletions test/einsequence.jl
@@ -34,6 +34,8 @@ using OMEinsum: IndexGroup, NestedEinsum, parse_nested, DynamicEinCode, isleaf,
size_info = Dict('k'=>2)
a, b, c, d = randn(2), randn(2,2), randn(2), randn(2)
@test ein"((i,ij),i),j->ik"(a, b, c, d; size_info=size_info) ein"i,ij,i,j->ik"(a, b, c, d; size_info=size_info)
size_dict = Dict([l=>2 for l in "ijkl"])
@test einsum!(ein"((i,ij),i),j->ik", (a, b, c, d), randn(2, 2), true, false, size_dict) ein"i,ij,i,j->ik"(a, b, c, d; size_info=size_info)
@test getixsv(ein"((i,ij),i),j->ik") == getixsv(ein"i,ij,i,j->ik") == getixsv(DynamicEinCode(ein"i,ij,i,j->ik")) == [['i'], ['i','j'], ['i'], ['j']]
@test getiyv(ein"((i,ij),i),j->ik") == getiyv(ein"i,ij,i,j->ik") == getiyv(DynamicEinCode(ein"i,ij,i,j->ik")) == ['i','k']
end
Expand Down
