Manual AD #170

Merged
merged 2 commits on Jun 25, 2024
29 changes: 29 additions & 0 deletions examples/manualad.jl
@@ -0,0 +1,29 @@
using OMEinsum
using OMEinsum: cost_and_gradient
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
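# Hedged sanity check (not part of the PR): this contraction is tr(A*B*C), so the
# gradient with respect to A is (B*C)'; `tr` comes from LinearAlgebra.
using LinearAlgebra: tr
@assert y[] ≈ tr(A * B * C)
@assert g[1] ≈ (B * C)'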

# evaluate the cost and the gradient of leaves
function gf(code, xs, res, ȳ = OMEinsum.init_gradient(code, xs))
cost, tree = OMEinsum.gradient_tree(code, xs, ȳ)
# extract the gradients on leaves (i.e. the input tensors).
return cost, OMEinsum.extract_leaves!(code, tree, res)
end

using Zygote
xA, xB, xC = randn(2, 3), randn(3, 4), randn(4, 2)
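# `gfunc` contracts the manual gradients with the fixed tensors xA, xB, xC, so
# differentiating it with Zygote below probes second-order derivatives
# (a Hessian-vector product that can be compared with the FiniteDiff Hessian at the end).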
function gfunc(A, B, C)
ȳ = fill(one(eltype(A)))
res = Zygote.Buffer(Any[nothing, nothing, nothing])
cost, (gA, gB, gC) = gf(ein"(ij, jk), ki->", (A, B, C), res, ȳ)
@info "summing"
return sum(gA .* xA) + sum(gB .* xB) + sum(gC .* xC)
end
Zygote.gradient(gfunc, A, B, C)


zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
mg = gf(ein"(ij, jk), ki->", (A, B, C), Any[nothing, nothing, nothing])

using FiniteDiff
h = FiniteDiff.finite_difference_hessian(v->ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[], [vec(A); vec(B); vec(C)])
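# Hedged extension of the example above (not part of the PR): FiniteDiff can also
# cross-check the first-order gradient returned by `gf` against a finite-difference gradient.
fd_g = FiniteDiff.finite_difference_gradient(
    v -> ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[],
    [vec(A); vec(B); vec(C)])
@assert isapprox(fd_g, reduce(vcat, vec.(mg[2])); rtol=1e-4)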
3 changes: 2 additions & 1 deletion src/OMEinsum.jl
@@ -6,7 +6,7 @@ using OMEinsumContractionOrders
using AbstractTrees
import LinearAlgebra: BlasFloat

export @ein_str, @ein, @ein!, ein
export @ein_str, @ein, @ein!, ein, @optein_str
export einsum!, einsum, dynamic_einsum
export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum
export getiyv, getixsv, uniquelabels, labeltype
@@ -40,6 +40,7 @@ include("interfaces.jl")
include("einsequence.jl")
include("slicing.jl")
include("autodiff.jl")
include("bp.jl")

include("contractionorder.jl")

15 changes: 15 additions & 0 deletions src/autodiff.jl
@@ -52,3 +52,18 @@ end
@non_differentiable get_size_dict!(::Any, ::Any, ::Any)
@non_differentiable DynamicEinCode(::Any, ::Any)
@non_differentiable DynamicEinCode(::Any)
@non_differentiable getixsv(::Any)

echo(x; tag="echo") = x
function ChainRulesCore.rrule(::typeof(echo), x; tag="echo")
@info "$tag: $x"
x, function (dy)
@info "$tag (back): x̄ = $dy"
return (NoTangent(), dy)
end
end

macro echo(var)
name = QuoteNode(var)
esc(:($var = $echo($var; tag="$($name)")))
end
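# Hedged usage sketch for the debugging helpers above (not part of the PR; requires
# Zygote, which is not a dependency of this file): `@echo x` rebinds `x = echo(x; tag = "x")`,
# logging the value on the forward pass and the incoming adjoint on the backward pass.
using OMEinsum, Zygote
function traced(x)
    OMEinsum.@echo x          # logs "x: [1.0, 2.0, 3.0]" when called
    return sum(abs2, x)
end
Zygote.gradient(traced, [1.0, 2.0, 3.0])   # additionally logs "x (back): x̄ = [2.0, 4.0, 6.0]"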
164 changes: 164 additions & 0 deletions src/bp.jl
@@ -0,0 +1,164 @@
# `CacheTree` stores intermediate `NestedEinsum` contraction results.
# It is a tree structure that is isomorphic to the contraction tree:
# `content` is the cached intermediate contraction result,
# `siblings` are the cached sub-trees for the arguments of the current node.
struct CacheTree{T}
content::AbstractArray{T}
siblings::Vector{CacheTree{T}}
end
CacheTree(content::AbstractArray{T}, siblings) where T = CacheTree(content, CacheTree{T}[siblings...])
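# Hedged illustration (not part of the PR): a leaf cache holds an input tensor and
# has no siblings, while an internal node holds the contracted result together with
# the caches of its arguments, mirroring the contraction tree.
leafA = CacheTree(randn(2, 3), CacheTree{Float64}[])
leafB = CacheTree(randn(3, 4), CacheTree{Float64}[])
node = CacheTree(randn(2, 4), [leafA, leafB])   # node.siblings holds the two leaf caches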

"""
cached_einsum(code, xs, size_dict)

Compute the einsum contraction and cache the intermediate contraction results.

### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.

### Returns
- `CacheTree`: The cached tree storing the intermediate results.
"""
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
# slicing is not supported yet.
if length(se.slicing) != 0
@warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
end
return cached_einsum(se.eins, xs, size_dict)
end

# recursively contract and cache a tensor network
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
if isleaf(code)
# For a leaf node, cache the input tensor
y = xs[tensorindex(code)]
return CacheTree(y, CacheTree{eltype(y)}[])
else
# For a non-leaf node, compute the einsum and cache the contraction result
caches = [cached_einsum(arg, xs, size_dict) for arg in siblings(code)]
# `einsum` evaluates the einsum contraction:
# its 1st argument is the contraction pattern,
# its 2nd argument is a tuple of input tensors,
# its 3rd argument is the size dictionary (label as the key, size as the value).
y = einsum(rootcode(code), ntuple(i -> caches[i].content, length(caches)), size_dict)
return CacheTree(y, caches)
end
end
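# Hedged usage sketch (not part of the PR): cache the intermediates of a small
# contraction; the size dictionary is built the same way as in `gradient_tree` below.
sketch_code = ein"(ij, jk), ki->"
sketch_xs = (randn(2, 3), randn(3, 4), randn(4, 2))
sketch_sizes = get_size_dict!(getixsv(sketch_code), sketch_xs, Dict{Char, Int}())
sketch_cache = cached_einsum(sketch_code, sketch_xs, sketch_sizes)
sketch_cache.content[]          # the cached scalar result
length(sketch_cache.siblings)   # 2: the `(ij, jk)` intermediate and the `ki` leaf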

"""
back_propagate(f, code, cache, ȳ, size_dict)

Back propagate the message `ȳ` through the cached tree `cache` and return a tree storing the intermediate messages.
The message can be, for example, gradients.

### Arguments
- `f`: The back-propagation rule. The signature is `f(eins, xs, y, size_dict, dy) -> dxs`, where
- `eins`: The contraction code at the current node.
- `xs`: The input tensors at the current node.
- `y`: The output tensor at the current node.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.
- `dy`: The message on the output tensor (`y`) to back-propagate through the current node.
- `dxs`: The message on the input tensors (`xs`) as the result of back-propagation.
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `cache`: The cached intermediate results, which can be generated by [`cached_einsum`](@ref).
- `ȳ`: The message to back-propagate.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.

### Returns
- `CacheTree`: The tree storing the intermediate messages.
"""
function back_propagate(f, se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
if length(se.slicing) != 0
@warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
end
return back_propagate(f, se.eins, cache, dy, size_dict)
end

function back_propagate(f, code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
if isleaf(code)
return CacheTree(dy, CacheTree{T}[])
else
xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
# `einsum_grad` is the back-propagation rule for einsum function.
# If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`
# Then the back-propagation pass is
# ```
# A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1)
# B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2)
# ...
# ```
# Let `L` be the loss; then `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`, and so on.
dxs = f(rootcode(code), xs, cache.content, size_dict, dy)
return CacheTree(dy, back_propagate.(Ref(f), siblings(code), cache.siblings, dxs, Ref(size_dict)))
end
end
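# Hedged illustration of the rule described above (not part of the PR): for the
# matrix product y = A*B written as ein"ij,jk->ik", the adjoint recovered by
# `einsum_grad` matches the familiar rule A̅ = y̅ * B'.
bp_A, bp_B = randn(2, 3), randn(3, 4)
bp_eins = ein"ij,jk->ik"
bp_sizes = Dict('i' => 2, 'j' => 3, 'k' => 4)
bp_ȳ = ones(2, 4)
bp_Ā = einsum_grad(getixs(bp_eins), (bp_A, bp_B), getiy(bp_eins), bp_sizes, bp_ȳ, 1)
bp_Ā ≈ bp_ȳ * bp_B'   # true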

# the main function for generating the gradient tree.
function gradient_tree(code::AbstractEinsum, xs, ȳ)
# infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
# forward compute and cache intermediate results.
cache = cached_einsum(code, xs, size_dict)
# back-propagate
function bprule(eins, @nospecialize(xs), @nospecialize(y), size_dict, @nospecialize(dy))
res = ntuple(i -> einsum_grad(getixs(eins), xs, getiy(eins), size_dict, dy, i), length(xs))
return res
end
return copy(cache.content), back_propagate(bprule, code, cache, ȳ, size_dict)
end
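# Hedged usage sketch (not part of the PR): the returned tree mirrors the
# contraction tree, with adjoints of intermediate tensors on internal nodes and
# adjoints of the inputs on the leaves.
gt_xs = (randn(2, 3), randn(3, 4), randn(4, 2))
gt_cost, gt_tree = gradient_tree(ein"(ij, jk), ki->", gt_xs, fill(1.0))
gt_cost[]                           # the forward contraction value
size(gt_tree.siblings[2].content)   # (4, 2): adjoint of the third input (labels `ki`)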

"""
cost_and_gradient(code, xs, ȳ)

Compute the cost and the gradients w.r.t the input tensors `xs`.

### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `ȳ`: The message to back-propagate. Default is `1`.

### Returns
- `cost`: The cost of the contraction.
- `grads`: The gradients w.r.t the input tensors.
"""
function cost_and_gradient(code, xs, ȳ = nothing)
if ȳ === nothing
ȳ = init_gradient(code, xs)
@assert ndims(ȳ) == 0 "The output must be a scalar! Or you need to feed the gradient manually. Got: $(ndims(ȳ))!"
end
cost, tree = gradient_tree(code, xs, ȳ)
# extract the gradients on leaves (i.e. the input tensors).
return cost, extract_leaves(code, tree)
end
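# Hedged usage sketch of the entry point above (not part of the PR); it mirrors
# examples/manualad.jl and test/bp.jl in this PR.
cg_xs = (randn(2, 3), randn(3, 4), randn(4, 2))
cg_cost, cg_grads = cost_and_gradient(ein"(ij, jk), ki->", cg_xs)
cg_cost[]           # scalar value of the contraction
size(cg_grads[1])   # (2, 3): gradient with respect to the first input tensor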

# initialize `y̅` as `1`. Note we always start from `L̅ := 1`.
function init_gradient(code, xs)
size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
output_size = getindex.(Ref(size_dict), getiyv(code))
ȳ = get_output_array(xs, output_size)
return fill!(ȳ, one(eltype(ȳ)))
end
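# Hedged illustration (not part of the PR): for a scalar-output contraction the
# seed returned by `init_gradient` is a 0-dimensional array filled with one(eltype).
ig_ȳ = init_gradient(ein"ij,ij->", (randn(2, 2), randn(2, 2)))
ndims(ig_ȳ), ig_ȳ[]   # (0, 1.0)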

# since slicing is not supported, we forward it to NestedEinsum.
extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache)

# extract gradients on leaf nodes.
function extract_leaves(code::NestedEinsum, cache::CacheTree)
res = Vector{Any}(undef, length(getixsv(code)))
return extract_leaves!(code, cache, res)
end

function extract_leaves!(code, cache, res)
if isleaf(code)
# extract
res[tensorindex(code)] = cache.content
else
# recurse deeper
for (subcode, sib) in zip(siblings(code), cache.siblings)
extract_leaves!(subcode, sib, res)
end
end
return res
end
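# Hedged illustration (not part of the PR): `extract_leaves!` writes each leaf
# gradient into `res` at the index of the corresponding input tensor, which is why
# examples/manualad.jl can pass a `Zygote.Buffer` as `res`.
el_xs = (randn(2, 3), randn(3, 4), randn(4, 2))
el_cost, el_tree = gradient_tree(ein"(ij, jk), ki->", el_xs, fill(1.0))
el_res = extract_leaves!(ein"(ij, jk), ki->", el_tree, Vector{Any}(undef, 3))
size.(el_res)   # [(2, 3), (3, 4), (4, 2)], ordered like the inputs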
10 changes: 10 additions & 0 deletions src/interfaces.jl
@@ -22,6 +22,16 @@ macro ein_str(s::AbstractString)
ein(s)
end

"""
optein"ij,jk,kl -> ik"(A, B, C)

String macro interface that is similar to [`@ein_str`](@ref), but with an optimized contraction order (all dimensions are assumed to have uniform size).
"""
macro optein_str(s::AbstractString)
code = ein(s)
optimize_code(code, uniformsize(code, 20), TreeSA(; ntrials=1, niters=10)).eins
end
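# Hedged usage sketch (not part of the PR): `optein` parses the string like `ein`
# and then optimizes the contraction order; the result evaluates to the same value.
oc = optein"ij,jk,ki->"
oc isa NestedEinsum                  # true, see test/interfaces.jl
oc_A, oc_B, oc_C = randn(5, 5), randn(5, 5), randn(5, 5)
oc(oc_A, oc_B, oc_C)[] ≈ ein"ij,jk,ki->"(oc_A, oc_B, oc_C)[]   # true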

function ein(s::AbstractString)
s = replace(replace(s, "\n" => ""), " "=>"")
m = match(r"([\(\)a-z,α-ω]*)->([a-zα-ω]*)", s)
2 changes: 1 addition & 1 deletion src/slicing.jl
@@ -113,7 +113,7 @@ function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} wher
it = SliceIterator(se, size_dict)
res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv))
eins_sliced = drop_slicedim(se.eins, se.slicing)
for slicemap in it
for slicemap in it # `slicemap` is a Dict storing a mapping from sliced_labels to the current slice index
# NOTE: @debug will break Zygote
# @debug "computing slice $k/$(length(it))"
xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs))
17 changes: 17 additions & 0 deletions test/bp.jl
@@ -0,0 +1,17 @@
using OMEinsum, Test, Zygote

@testset "bp check" begin
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
cost0 = ein"(ij, jk), ki->"(A, B, C)[]
zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
cost, mg = OMEinsum.cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
@test cost[] ≈ cost0
@test all(zg .≈ mg)

code = optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA())
cost0 = code(A, B, C)[]
zg = Zygote.gradient((a, b, c)->code(a, b, c)[], A, B, C)
cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C))
@test cost[] ≈ cost0
@test all(zg .≈ mg)
end
5 changes: 5 additions & 0 deletions test/interfaces.jl
@@ -14,3 +14,8 @@ using OMEinsum: get_size_dict
ijk->
ikl" == ein"ijk,ijk->ikl"
end

@testset "opein" begin
code = optein"ij,jk,ki->"
@test code isa NestedEinsum
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -55,6 +55,10 @@ end
include("contractionorder.jl")
end

@testset "back propagation" begin
include("bp.jl")
end

@testset "docstring" begin
Documenter.doctest(OMEinsum; manual=false)
end