Add the inplace version of the einsum function (#159)
* save inplace omeinsum

* add inplace binary rules

* fix docstrings

* unitary - initial

* save

* binary rule pass

* rename api

* tested unary rule

* new einsum.jl

* update

* update

* save

* update, no obvious bug

* benchmark perm

* update

* inplace nested einsum

* better compiling

* fix cuda

* fix bugs

* cuda einsum

* improve test coverage
GiggleLiu authored Jan 13, 2024
1 parent 5bc5cc8 commit 34b8d50
Showing 21 changed files with 916 additions and 599 deletions.
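
For orientation before the per-file diff, here is a minimal sketch of how the new in-place entry point might be called. The `einsum!(code, xs, y, sx, sy, size_dict)` argument order and the `OMEinsum.get_size_dict` helper are inferred from the calls visible in the diff below, so the released public API may differ in detail:

```julia
using OMEinsum

# Sketch only: the in-place form is expected to update y as  y .= sx .* einsum(code, xs) .+ sy .* y
code = ein"ij,jk->ik"
A, B = randn(2, 3), randn(3, 4)
C = zeros(2, 4)
size_dict = OMEinsum.get_size_dict(getixsv(code), (A, B))
einsum!(code, (A, B), C, true, false, size_dict)   # with sx = true, sy = false, C now holds A * B
```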
14 changes: 14 additions & 0 deletions README.md
@@ -224,6 +224,20 @@ julia> gradient(x->optcode(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum
```
This tells us that even if we allow duplicates on one vertex, there are no 3-colourings for the Petersen graph.

## Comparison with other packages
Similar packages include:
- [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl) and [TensorKit.jl](https://github.com/Jutho/TensorKit.jl)
- [ITensors.jl](https://github.com/ITensor/ITensors.jl)

Compared with the above packages, `OMEinsum` is optimized for large-scale tensor network (or einsum, sum-product network) contraction. Its main advantages are:
- `OMEinsum` has better support for very high-dimensional tensor networks and their contraction orders.
- `OMEinsum` allows an index to appear multiple times (see the short example below).
- `OMEinsum` has well-tested support for generic element types.

However, `OMEinsum` also has some disadvantages:
- `OMEinsum` does not support good quantum numbers (i.e. block-sparse symmetric tensors).
- `OMEinsum` is less optimized for small-scale problems.
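
For example, a minimal sketch of the repeated-index feature (illustrative usage, not part of this diff):

```julia
using OMEinsum, LinearAlgebra

A = randn(3, 3)
# the label `i` appears twice in a single input, so the diagonal is contracted
ein"ii->"(A)[] ≈ tr(A)   # true
```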

## Contribute

Suggestions and comments in the _Issues_ are welcome.
3 changes: 3 additions & 0 deletions benchmark/Project.toml
@@ -0,0 +1,3 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
10 changes: 10 additions & 0 deletions benchmark/large.jl
@@ -0,0 +1,10 @@
using OMEinsum, BenchmarkTools

function benchmark_tensorpermute()
    # tensorpermute
    A = randn(fill(2, 28)...);
    C = zero(A);
    perm = [18 22 11 21 15 9 10 19 24 14 5 1 17 20 26 25 28 27 7 6 3 13 12 16 8 23 2 4] |> vec
    @btime OMEinsum.tensorpermute!($C, $A, $perm, true, false) evals=3
    nothing
end
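
One possible way to run this micro-benchmark from a local checkout (the environment-activation steps are an assumption, not part of the commit):

```julia
# assumes the current directory is the repository root
using Pkg
Pkg.activate("benchmark")   # the environment defined by benchmark/Project.toml above
Pkg.develop(path=".")       # benchmark the local OMEinsum rather than the registry release
Pkg.instantiate()

include("benchmark/large.jl")
benchmark_tensorpermute()
```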
97 changes: 53 additions & 44 deletions ext/CUDAExt.jl
@@ -1,6 +1,6 @@
module CUDAExt

import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm, asscalar
import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm!, asscalar, @flatten_addmul!
using OMEinsum: EinArray, Diag, Repeat, Duplicate, DefaultRule, EinCode, DynamicEinCode, StaticEinCode, NestedEinsum, SimpleBinaryRule, match_rule, loop_einsum, getiy, getixs, _unique, einarray, align_eltypes, siblings, isleaf, tensorindex, _safe_set, rootcode
import OMEinsum
using LinearAlgebra
@@ -15,54 +15,64 @@ _unwrap(x::CuArray) = x

asarray(x, arr::CuArray) = CuArray(fill(x, ()))
asarray(x::AbstractArray, y::CuArray) = x
asscalar(x::DenseCuArray) = Array(x)[]
asscalar(x::CUDAArrayTypes) = Array(x)[]

# to avoid returning a ReshapedArray
OMEinsum.safe_reshape(x::CuArray, sz) = reshape(x, (sz...,))
OMEinsum.safe_reshape(x::Adjoint{T, <:CuArray{T}} where T, sz) = reshape(CuArray(x), (sz...,))
OMEinsum.safe_reshape(x::Transpose{T, <:CuArray{T}} where T, sz) = reshape(CuArray(x), (sz...,))

Base.Array(x::Base.ReshapedArray{T,0,<:CuArray}) where T = Array(x.parent)

function get_output_array(xs::NTuple{N, DenseCuArray{<:Any,M} where M}, size; has_repeated_indices=true) where N
function get_output_array(xs::NTuple{N, CUDAArrayTypes{<:Any,M} where M}, size; fillzero=true) where N
CUDA.zeros(promote_type(map(eltype,xs)...), size...)
end
function get_output_array(xs::NTuple{N, DenseCuArray{T,M} where M}, size; has_repeated_indices=true) where {T,N}
function get_output_array(xs::NTuple{N, CUDAArrayTypes{T,M} where M}, size; fillzero=true) where {T,N}
CUDA.zeros(T, size...)
end

CUDA.cudaconvert(A::EinArray{T}) where T = EinArray{T}(cudaconvert.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS)
CUDA.cu(A::EinArray{T}) where T = EinArray{T}(cu.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS)

for TP in [:Diag, :Repeat, :Duplicate, :DefaultRule]
@eval function einsum(::$TP, ixs, iy, xs::Tuple{<:DenseCuArray}, size_dict::Dict{LT}) where LT
@debug "cueinsum fallback to loop_einsum" rule ixs => iy size.(xs)
loop_einsum(EinCode(ixs, iy), xs, size_dict)
for TP in [:Diag, :Repeat, :Duplicate]
@eval function OMEinsum.unary_einsum!(::$TP, ix, iy, x::CUDAArrayTypes, y::CUDAArrayTypes, sx, sy)
@debug "cueinsum fallback to loop_einsum" rule ix => iy size(x)
size_dict = OMEinsum.get_size_dict((ix, iy), (x, y))
loop_einsum!((ix,), iy, (x,), y, sx, sy, size_dict)
end
end

function einsum(::SimpleBinaryRule{('j',), ('j',), ()}, xs::NTuple{2, DenseCuArray})
dropdims(reshape(xs[1],1,:) * xs[2]; dims=1)
end

function loop_einsum!(code::EinCode,
xs::NTuple{N, DenseCuArray{<:Any,M} where M},
y::DenseCuArray{T,L}, size_dict::Dict{LT}) where {N,L,T, LT}
iy = (getiy(code)...,)
ixs = (Tuple.(getixs(code))...,)
function loop_einsum!(ixs0, iy0,
xs::NTuple{N, CUDAArrayTypes{<:Any,M} where M},
y::CUDAArrayTypes{T,L}, sx, sy, size_dict::Dict{LT}) where {N,L,T, LT}
iy = (iy0...,)
ixs = (Tuple.(ixs0)...,)
iy_ = _unique(LT,iy)
NO = length(iy_)
A = einarray(Val(ixs), Val(iy), xs, size_dict)
A = einarray(Val(ixs), Val((iy_...,)), xs, size_dict)
if NO == length(iy)
y = reshape(y, fill(1, ndims(A)-NO)...,size(y)...)
raw = Base.mapreducedim!(x->x, +, y, A)
raw_ = similar(y, (fill(1, ndims(A)-NO)...,size(y)...,))
fill!(raw_, zero(T))
Base.mapreducedim!(x->x, +, raw_, A)
if ndims(A)-NO > 0 # fix 1.7 compatibility
raw = dropdims(raw, dims=(1:ndims(A)-NO...,))
raw = dropdims(raw_, dims=(1:ndims(A)-NO...,))
else
raw = raw_
end
return raw
return @flatten_addmul! sy * y + sx * raw
else
y_ = CUDA.zeros(T, size(A)[end-NO+1:end]...)
y_ = reshape(y_, fill(1, ndims(A)-NO)...,size(y_)...)
if iszero(sy)
fill!(y, zero(T))
else
lmul!(sy, y)
end
y_ = similar(y, (fill(1, ndims(A)-NO)...,[size_dict[l] for l in iy_]...))
fill!(y_, zero(T))
raw = Base.mapreducedim!(x->x, +, y_, A)
if ndims(A)-NO > 0 # fix 1.7 compatibility
raw = dropdims(raw, dims=(1:ndims(A)-NO...,))
end
return expanddims!(Val{((iy_...,),)}(), Val{iy}(), raw, y)
return expanddims!(Val{((iy_...,),)}(), Val{iy}(), raw, y, sx)
end
end

@@ -72,31 +82,41 @@ end
Expr(:tuple, ids...)
end

function expanddims!(::Val{ixs}, ::Val{iy}, x, y) where {ixs,iy}
function expanddims!(::Val{ixs}, ::Val{iy}, x, y, sx) where {ixs,iy}
nthreads = 256
nblocks = cld(prod(size(x)), nthreads)
CIS = CartesianIndices(x)
@inline function kernel(y, x)
i = (blockIdx().x-1) * blockDim().x + threadIdx().x
i > length(x) && return nothing
@inbounds yi = expandind(Val{ixs}(), Val{iy}(), CIS[i].I)
@inbounds y[CartesianIndex(yi)] = x[i]
@inbounds y[CartesianIndex(yi)] += sx * x[i]
nothing
end
@cuda(blocks=nblocks, threads=nthreads, kernel(y, x))
return y
end

function _batched_gemm(C1::Char, C2::Char, A::DenseCuArray{T1,3}, B::DenseCuArray{T2,3}) where {T1<:CuBlasFloat, T2<:CuBlasFloat}
CUDA.CUBLAS.gemm_strided_batched(C1, C2, align_eltypes(A,B)...)
end

function einsum(::SimpleBinaryRule{(),(), ()}, xs::NTuple{2, DenseCuArray})
asarray(Array(xs[1])[] * Array(xs[2])[], xs[1])
function _batched_gemm!(C1::Char, C2::Char, alpha, A::CUDAArrayTypes{T1,3}, B::CUDAArrayTypes{T2,3}, beta, C::CUDAArrayTypes{T3,3}) where {T1<:CuBlasFloat, T2<:CuBlasFloat, T3<:CuBlasFloat}
CUDA.CUBLAS.gemm_strided_batched!(C1, C2, alpha, T1 == T3 ? A : T3.(A), T2 == T3 ? B : T3.(B), beta, C)
end

Base.ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}}) = 0

function einsum!(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), @nospecialize(y::CUDAArrayTypes), sx, sy, size_dict::Dict; active_free=false)
# do not use map because the static overhead is too large
# do not use `setindex!` because we need to make the AD work
mxs = Vector{AbstractArray}(undef, length(siblings(neinsum)))
for (i, arg) in enumerate(siblings(neinsum))
mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum(arg, xs, similar(y, ([size_dict[l] for l in getiy(rootcode(arg))]...,)), true, false, size_dict; active_free=active_free))
end
res = einsum!(rootcode(neinsum), (mxs...,), y, sx, sy, size_dict)
active_free && for mx in mxs # free CuArray aggressively.
CUDA.unsafe_free!(mx)
end
return res
end

function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict; active_free=false)
# do not use map because the static overhead is too large
# do not use `setindex!` because we need to make the AD work
@@ -111,17 +131,6 @@ function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes
return res
end

# to dispatch Adjoint correctly
@generated function einsum(code::StaticEinCode{LT,ixs, iy}, xs::NTuple{N,CUDAArrayTypes} where N, size_dict::Dict{LT}) where {LT, ixs, iy}
rule = match_rule(ixs, iy)
:(einsum($rule, $ixs, $iy, _unwrap.(xs), size_dict))
end

function einsum(code::DynamicEinCode, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict)
rule = match_rule(getixs(code), getiy(code))
einsum(rule, getixs(code), getiy(code), _unwrap.(xs), size_dict)
end

@info("OMEinsum loaded the CUDA module successfully")

end
4 changes: 3 additions & 1 deletion src/OMEinsum.jl
@@ -7,7 +7,7 @@ using AbstractTrees
import LinearAlgebra: BlasFloat

export @ein_str, @ein, ein
export einsum, dynamic_einsum
export einsum!, einsum, dynamic_einsum
export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum
export getiyv, getixsv, uniquelabels, labeltype
export flop
@@ -33,6 +33,8 @@ include("utils.jl")

include("unaryrules.jl")
include("binaryrules.jl")
include("matchrule.jl")
include("einsum.jl")

include("interfaces.jl")
include("einsequence.jl")
