Manual AD (#170)
* save

* update
GiggleLiu authored Jun 25, 2024
1 parent 329d988 commit 2d6627f
Showing 9 changed files with 247 additions and 2 deletions.
29 changes: 29 additions & 0 deletions examples/manualad.jl
@@ -0,0 +1,29 @@
using OMEinsum
using OMEinsum: cost_and_gradient
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))

# evaluate the cost and the gradient of leaves
function gf(code, xs, res, ȳ = OMEinsum.init_gradient(code, xs))
cost, tree = OMEinsum.gradient_tree(code, xs, ȳ)
# extract the gradients on leaves (i.e. the input tensors).
return cost, OMEinsum.extract_leaves!(code, tree, res)
end

using Zygote
xA, xB, xC = randn(2, 3), randn(3, 4), randn(4, 2)
function gfunc(A, B, C)
ȳ = fill(one(eltype(A)))
res = Zygote.Buffer(Any[nothing, nothing, nothing])
cost, (gA, gB, gC) = gf(ein"(ij, jk), ki->", (A, B, C), res, ȳ)
@info "summing"
return sum(gA .* xA) + sum(gB .* xB) + sum(gC .* xC)
end
Zygote.gradient(gfunc, A, B, C)


zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
mg = gf(ein"(ij, jk), ki->", (A, B, C), Any[nothing, nothing, nothing])

using FiniteDiff
h = FiniteDiff.finite_difference_hessian(v->ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[], [vec(A); vec(B); vec(C)])
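The example above checks second-order consistency through Zygote and the FiniteDiff Hessian. A simpler first-order cross-check of the manual gradient against finite differences could look like the following sketch (it reuses `A`, `B`, `C` from above and assumes `FiniteDiff.finite_difference_gradient`):

using FiniteDiff
using OMEinsum: cost_and_gradient

# compare the manual gradient with a finite-difference gradient of the scalar contraction
cost, (gA, gB, gC) = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
fd = FiniteDiff.finite_difference_gradient(
    v -> ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[],
    [vec(A); vec(B); vec(C)])
isapprox([vec(gA); vec(gB); vec(gC)], fd; rtol=1e-5)   # should be true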
3 changes: 2 additions & 1 deletion src/OMEinsum.jl
@@ -6,7 +6,7 @@ using OMEinsumContractionOrders
using AbstractTrees
import LinearAlgebra: BlasFloat

export @ein_str, @ein, @ein!, ein
export @ein_str, @ein, @ein!, ein, @optein_str
export einsum!, einsum, dynamic_einsum
export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum
export getiyv, getixsv, uniquelabels, labeltype
@@ -40,6 +40,7 @@ include("interfaces.jl")
include("einsequence.jl")
include("slicing.jl")
include("autodiff.jl")
include("bp.jl")

include("contractionorder.jl")

15 changes: 15 additions & 0 deletions src/autodiff.jl
@@ -52,3 +52,18 @@ end
@non_differentiable get_size_dict!(::Any, ::Any, ::Any)
@non_differentiable DynamicEinCode(::Any, ::Any)
@non_differentiable DynamicEinCode(::Any)
@non_differentiable getixsv(::Any)

echo(x; tag="echo") = x
function ChainRulesCore.rrule(::typeof(echo), x; tag="echo")
@info "$tag: $x"
x, function (dy)
@info "$tag (back): x̄ = $dy"
return (NoTangent(), dy)
end
end

macro echo(var)
name = QuoteNode(var)
esc(:($var = $echo($var; tag="$($name)")))
end
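These helpers are mainly for debugging: `echo` behaves as the identity, while its `rrule` logs the value on the forward pass and the incoming cotangent on the backward pass. A minimal sketch of how it could be used with Zygote (the loss function is illustrative):

using OMEinsum, Zygote

loss(x) = sum(abs2, OMEinsum.echo(x; tag="x"))   # logs "x: ..." on the forward pass
Zygote.gradient(loss, randn(3))                  # additionally logs "x (back): x̄ = ..."

The `@echo var` macro is a shorthand that rebinds `var = echo(var; tag="var")` inside a function body.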
164 changes: 164 additions & 0 deletions src/bp.jl
@@ -0,0 +1,164 @@
# `CacheTree` stores intermediate `NestedEinsum` contraction results.
# It is a tree structure that is isomorphic to the contraction tree:
# `content` is the cached intermediate contraction result, and
# `siblings` stores the cached sub-trees, one for each argument of the current node.
struct CacheTree{T}
content::AbstractArray{T}
siblings::Vector{CacheTree{T}}
end
CacheTree(content::AbstractArray{T}, siblings) where T = CacheTree(content, CacheTree{T}[siblings...])

"""
cached_einsum(code, xs, size_dict)
Compute the einsum contraction and cache the intermediate contraction results.
### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.
### Returns
- `CacheTree`: The cached tree storing the intermediate results.
"""
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
# slicing is not supported yet.
if length(se.slicing) != 0
@warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
end
return cached_einsum(se.eins, xs, size_dict)
end

# recursively contract and cache a tensor network
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
if isleaf(code)
# For a leaf node, cache the input tensor
y = xs[tensorindex(code)]
return CacheTree(y, CacheTree{eltype(y)}[])
else
# For a non-leaf node, compute the einsum and cache the contraction result
caches = [cached_einsum(arg, xs, size_dict) for arg in siblings(code)]
# `einsum` evaluates the einsum contraction:
# its 1st argument is the contraction pattern,
# its 2nd argument is a tuple of input tensors,
# and its 3rd argument is the size dictionary (label as the key, size as the value).
y = einsum(rootcode(code), ntuple(i -> caches[i].content, length(caches)), size_dict)
return CacheTree(y, caches)
end
end
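A usage sketch for `cached_einsum`, mirroring the first lines of `gradient_tree` below (the contraction and tensor sizes are illustrative):

using OMEinsum

code = ein"(ij, jk), ki->"
xs = (randn(2, 3), randn(3, 4), randn(4, 2))
size_dict = OMEinsum.get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
cache = OMEinsum.cached_einsum(code, xs, size_dict)
cache.content    # the cached result at the root, a 0-dimensional array here
cache.siblings   # cached sub-results, one per argument of the top-level contraction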

"""
back_propagate(f, code, cache, ȳ, size_dict)
Back propagate the message `ȳ` through the cached tree `cache` and return a tree storing the intermediate messages.
The message can be, for example, gradients.
### Arguments
- `f`: The back-propagation rule. The signature is `f(eins, xs, y, size_dict, dy) -> dxs`, where
- `eins`: The contraction code at the current node.
- `xs`: The input tensors at the current node.
- `y`: The output tensor at the current node.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.
- `dy`: The message on the output tensor (`y`) to back-propagate through the current node.
- `dxs`: The message on the input tensors (`xs`) as the result of back-propagation.
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `cache`: The cached intermediate results, which can be generated by [`cached_einsum`](@ref).
- `ȳ`: The message to back-propagate.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.
### Returns
- `CacheTree`: The tree storing the intermediate messages.
"""
function back_propagate(f, se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
if length(se.slicing) != 0
@warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
end
return back_propagate(f, se.eins, cache, dy, size_dict)
end

function back_propagate(f, code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
if isleaf(code)
return CacheTree(dy, CacheTree{T}[])
else
xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
# `einsum_grad` is the back-propagation rule for einsum function.
# If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`
# Then the back-propagation pass is
# ```
# A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1)
# B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2)
# ...
# ```
# Let `L` be the loss; then `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`, and so on.
dxs = f(rootcode(code), xs, cache.content, size_dict, dy)
return CacheTree(dy, back_propagate.(Ref(f), siblings(code), cache.siblings, dxs, Ref(size_dict)))
end
end

# the main function for generating the gradient tree.
function gradient_tree(code::AbstractEinsum, xs, ȳ)
# infer the sizes from the contraction code and the input tensors `xs`, returning a label-to-size dictionary.
size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
# forward compute and cache intermediate results.
cache = cached_einsum(code, xs, size_dict)
# back-propagate
function bprule(eins, @nospecialize(xs), @nospecialize(y), size_dict, @nospecialize(dy))
res = ntuple(i -> einsum_grad(getixs(eins), xs, getiy(eins), size_dict, dy, i), length(xs))
return res
end
return copy(cache.content), back_propagate(bprule, code, cache, ȳ, size_dict)
end

"""
cost_and_gradient(code, xs, ȳ)
Compute the cost and the gradients w.r.t the input tensors `xs`.
### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `ȳ`: The message to back-propagate. Default is `1`.
### Returns
- `cost`: The cost of the contraction.
- `grads`: The gradients w.r.t the input tensors.
"""
function cost_and_gradient(code, xs, ȳ = nothing)
if ȳ === nothing
ȳ = init_gradient(code, xs)
@assert ndims(ȳ) == 0 "The output must be a scalar; otherwise, please supply the gradient `ȳ` manually. Got: $(ndims(ȳ)) dimensions!"
end
cost, tree = gradient_tree(code, xs, ȳ)
# extract the gradients on leaves (i.e. the input tensors).
return cost, extract_leaves(code, tree)
end
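When the output of the contraction is not a scalar, the cotangent `ȳ` has to be supplied explicitly. A sketch (the contraction and shapes are illustrative):

using OMEinsum
using OMEinsum: cost_and_gradient

A, B, C = randn(2, 3), randn(3, 4), randn(4, 5)
ȳ = ones(2, 5)   # one cotangent entry per element of the 2×5 output
y, (gA, gB, gC) = cost_and_gradient(ein"(ij, jk), kl->il", (A, B, C), ȳ)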

# initialize `y̅` as `1`. Note we always start from `L̅ := 1`.
function init_gradient(code, xs)
size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
output_size = getindex.(Ref(size_dict), getiyv(code))
ȳ = get_output_array(xs, output_size)
return fill!(ȳ, one(eltype(ȳ)))
end

# since slicing is not supported, we forward it to NestedEinsum.
extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache)

# extract gradients on leaf nodes.
function extract_leaves(code::NestedEinsum, cache::CacheTree)
res = Vector{Any}(undef, length(getixsv(code)))
return extract_leaves!(code, cache, res)
end

function extract_leaves!(code, cache, res)
if isleaf(code)
# store the gradient for this leaf (input tensor)
res[tensorindex(code)] = cache.content
else
# recurse into the sub-trees
for (subcode, sib) in zip(siblings(code), cache.siblings)
extract_leaves!(subcode, sib, res)
end
end
return res
end
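Putting the pieces together, `cost_and_gradient` is roughly equivalent to the following sketch, which mirrors the `gf` helper in `examples/manualad.jl` (the contraction and sizes are illustrative):

using OMEinsum

code = ein"(ij, jk), ki->"
xs = (randn(2, 3), randn(3, 4), randn(4, 2))
ȳ = OMEinsum.init_gradient(code, xs)              # 0-dimensional array filled with 1
cost, tree = OMEinsum.gradient_tree(code, xs, ȳ)  # forward cache + back-propagated messages
grads = OMEinsum.extract_leaves(code, tree)       # gradients w.r.t. the input tensors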
10 changes: 10 additions & 0 deletions src/interfaces.jl
@@ -22,6 +22,16 @@ macro ein_str(s::AbstractString)
ein(s)
end

"""
optein"ij,jk,kl -> ik"(A, B, C)
String macro interface similar to [`@ein_str`](@ref), but with an optimized contraction order (dimensions are assumed to be uniform).
"""
macro optein_str(s::AbstractString)
code = ein(s)
optimize_code(code, uniformsize(code, 20), TreeSA(; ntrials=1, niters=10)).eins
end
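A minimal usage sketch for the string macro (the contraction and tensor sizes are illustrative):

using OMEinsum

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = optein"ij,jk,ki->"   # a NestedEinsum with a TreeSA-optimized contraction order
code(A, B, C)[]             # same value as ein"ij,jk,ki->"(A, B, C)[]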

function ein(s::AbstractString)
s = replace(replace(s, "\n" => ""), " "=>"")
m = match(r"([\(\)a-z,α-ω]*)->([a-zα-ω]*)", s)
2 changes: 1 addition & 1 deletion src/slicing.jl
@@ -116,7 +116,7 @@ function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} wher
it = SliceIterator(se, size_dict)
res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv))
eins_sliced = drop_slicedim(se.eins, se.slicing)
for slicemap in it
for slicemap in it # `slicemap` is a Dict storing a mapping from sliced_labels to the current slice index
# NOTE: @debug will break Zygote
# @debug "computing slice $k/$(length(it))"
xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs))
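For context, a `SlicedEinsum` with a nonzero number of slices is typically produced by the contraction-order optimizer; a sketch, assuming `TreeSA` accepts an `nslices` keyword (sizes are illustrative):

using OMEinsum

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = ein"ij, jk, ki->"
sliced = optimize_code(code, uniformsize(code, 2), TreeSA(ntrials=1, nslices=1))
sliced(A, B, C)[]   # evaluated slice by slice through the loop above; same value as code(A, B, C)[]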
17 changes: 17 additions & 0 deletions test/bp.jl
@@ -0,0 +1,17 @@
using OMEinsum, Test, Zygote

@testset "bp check" begin
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
cost0 = ein"(ij, jk), ki->"(A, B, C)[]
zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
cost, mg = OMEinsum.cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
@test cost[] ≈ cost0
@test all(zg .≈ mg)

code = optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA())
cost0 = code(A, B, C)[]
zg = Zygote.gradient((a, b, c)->code(a, b, c)[], A, B, C)
cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C))
@test cost[] ≈ cost0
@test all(zg .≈ mg)
end
5 changes: 5 additions & 0 deletions test/interfaces.jl
@@ -14,3 +14,8 @@ using OMEinsum: get_size_dict
ijk->
ikl" == ein"ijk,ijk->ikl"
end

@testset "opein" begin
code = optein"ij,jk,ki->"
@test code isa NestedEinsum
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -55,6 +55,10 @@ end
include("contractionorder.jl")
end

@testset "back propagation" begin
include("bp.jl")
end

@testset "docstring" begin
Documenter.doctest(OMEinsum; manual=false)
end
