diff --git a/examples/manualad.jl b/examples/manualad.jl
new file mode 100644
index 0000000..44bada8
--- /dev/null
+++ b/examples/manualad.jl
@@ -0,0 +1,29 @@
+using OMEinsum
+using OMEinsum: cost_and_gradient
+A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
+y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
+
+# evaluate the cost and the gradient of leaves
+function gf(code, xs, res, ȳ = OMEinsum.init_gradient(code, xs))
+    cost, tree = OMEinsum.gradient_tree(code, xs, ȳ)
+    # extract the gradients on leaves (i.e. the input tensors).
+    return cost, OMEinsum.extract_leaves!(code, tree, res)
+end
+
+using Zygote
+xA, xB, xC = randn(2, 3), randn(3, 4), randn(4, 2)
+function gfunc(A, B, C)
+    ȳ = fill(one(eltype(A)))
+    res = Zygote.Buffer(Any[nothing, nothing, nothing])
+    cost, (gA, gB, gC) = gf(ein"(ij, jk), ki->", (A, B, C), res, ȳ)
+    @info "summing"
+    return sum(gA .* xA) + sum(gB .* xB) + sum(gC .* xC)
+end
+Zygote.gradient(gfunc, A, B, C)
+
+
+zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
+mg = gf(ein"(ij, jk), ki->", (A, B, C), Any[nothing, nothing, nothing])
+
+using FiniteDiff
+h = FiniteDiff.finite_difference_hessian(v->ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[], [vec(A); vec(B); vec(C)])
\ No newline at end of file
diff --git a/src/OMEinsum.jl b/src/OMEinsum.jl
index 0b6ede9..ff90099 100644
--- a/src/OMEinsum.jl
+++ b/src/OMEinsum.jl
@@ -6,7 +6,7 @@ using OMEinsumContractionOrders
 using AbstractTrees
 import LinearAlgebra: BlasFloat
 
-export @ein_str, @ein, @ein!, ein
+export @ein_str, @ein, @ein!, ein, @optein_str
 export einsum!, einsum, dynamic_einsum
 export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum
 export getiyv, getixsv, uniquelabels, labeltype
@@ -40,6 +40,7 @@ include("interfaces.jl")
 include("einsequence.jl")
 include("slicing.jl")
 include("autodiff.jl")
+include("bp.jl")
 include("contractionorder.jl")
 
diff --git a/src/autodiff.jl b/src/autodiff.jl
index 433fb0b..3f3ea96 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -52,3 +52,18 @@ end
 @non_differentiable get_size_dict!(::Any, ::Any, ::Any)
 @non_differentiable DynamicEinCode(::Any, ::Any)
 @non_differentiable DynamicEinCode(::Any)
+@non_differentiable getixsv(::Any)
+
+echo(x; tag="echo") = x
+function ChainRulesCore.rrule(::typeof(echo), x; tag="echo")
+    @info "$tag: $x"
+    x, function (dy)
+        @info "$tag (back): x̄ = $dy"
+        return (NoTangent(), dy)
+    end
+end
+
+macro echo(var)
+    name = QuoteNode(var)
+    esc(:($var = $echo($var; tag="$($name)")))
+end
\ No newline at end of file
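The `echo` helper and `@echo` macro added above are undocumented debugging aids; a possible usage sketch (not part of the patch, assuming Zygote picks up the keyword-argument rrule as usual, with `f` and its inputs made up for illustration):

using OMEinsum, Zygote

function f(A, B)
    OMEinsum.@echo A   # rebinds `A = echo(A; tag = "A")`, so the rrule logs A on the forward pass and Ā on the backward pass
    return ein"ij,jk->"(A, B)[]
end

Zygote.gradient(f, randn(2, 3), randn(3, 2))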
diff --git a/src/bp.jl b/src/bp.jl
new file mode 100644
index 0000000..17ac2ef
--- /dev/null
+++ b/src/bp.jl
@@ -0,0 +1,164 @@
+# `CacheTree` stores intermediate `NestedEinsum` contraction results.
+# It is a tree structure that is isomorphic to the contraction tree:
+# `content` is the cached intermediate contraction result,
+# `siblings` are the cached results of the arguments (child nodes) of the current node.
+struct CacheTree{T}
+    content::AbstractArray{T}
+    siblings::Vector{CacheTree{T}}
+end
+CacheTree(content::AbstractArray{T}, siblings) where T = CacheTree(content, CacheTree{T}[siblings...])
+
+"""
+    cached_einsum(code, xs, size_dict)
+
+Compute the einsum contraction and cache the intermediate contraction results.
+
+### Arguments
+- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
+- `xs`: The input tensors.
+- `size_dict`: The size dictionary, which maps each label to the size of the corresponding dimension.
+
+### Returns
+- `CacheTree`: The cached tree storing the intermediate results.
+"""
+function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
+    # slicing is not supported yet.
+    if length(se.slicing) != 0
+        @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
+    end
+    return cached_einsum(se.eins, xs, size_dict)
+end
+
+# recursively contract and cache a tensor network
+function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
+    if isleaf(code)
+        # For a leaf node, cache the input tensor.
+        y = xs[tensorindex(code)]
+        return CacheTree(y, CacheTree{eltype(y)}[])
+    else
+        # For a non-leaf node, compute the einsum and cache the contraction result.
+        caches = [cached_einsum(arg, xs, size_dict) for arg in siblings(code)]
+        # `einsum` evaluates the einsum contraction:
+        # its 1st argument is the contraction pattern,
+        # its 2nd argument is a tuple of input tensors,
+        # its 3rd argument is the size dictionary (label as the key, size as the value).
+        y = einsum(rootcode(code), ntuple(i -> caches[i].content, length(caches)), size_dict)
+        return CacheTree(y, caches)
+    end
+end
+
+"""
+    back_propagate(f, code, cache, ȳ, size_dict)
+
+Back-propagate the message `ȳ` through the cached tree `cache` and return a tree storing the intermediate messages.
+The messages can be, e.g., gradients.
+
+### Arguments
+- `f`: The back-propagation rule. The signature is `f(eins, xs, y, size_dict, dy) -> dxs`, where
+    - `eins`: The contraction code at the current node.
+    - `xs`: The input tensors at the current node.
+    - `y`: The output tensor at the current node.
+    - `size_dict`: The size dictionary, which maps each label to the size of the corresponding dimension.
+    - `dy`: The message on the output tensor (`y`) to back-propagate through the current node.
+    - `dxs`: The messages on the input tensors (`xs`) resulting from the back-propagation.
+- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
+- `cache`: The cached intermediate results, which can be generated by [`cached_einsum`](@ref).
+- `ȳ`: The message to back-propagate.
+- `size_dict`: The size dictionary, which maps each label to the size of the corresponding dimension.
+
+### Returns
+- `CacheTree`: The tree storing the intermediate messages.
+"""
+function back_propagate(f, se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
+    if length(se.slicing) != 0
+        @warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
+    end
+    return back_propagate(f, se.eins, cache, dy, size_dict)
+end
+
+function back_propagate(f, code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
+    if isleaf(code)
+        return CacheTree(dy, CacheTree{T}[])
+    else
+        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
+        # `einsum_grad` is the back-propagation rule for the `einsum` function.
+        # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`,
+        # then the back-propagation pass is
+        # ```
+        # A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1)
+        # B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2)
+        # ...
+        # ```
+        # Let `L` be the loss; then `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`, ...
+        dxs = f(rootcode(code), xs, cache.content, size_dict, dy)
+        return CacheTree(dy, back_propagate.(Ref(f), siblings(code), cache.siblings, dxs, Ref(size_dict)))
+    end
+end
+
+# the main function for generating the gradient tree.
+function gradient_tree(code::AbstractEinsum, xs, ȳ)
+    # infer the sizes from the contraction code and the input tensors `xs`, returning a label-size dictionary.
+    size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
+    # forward pass: compute and cache the intermediate results.
+    cache = cached_einsum(code, xs, size_dict)
+    # back-propagate
+    function bprule(eins, @nospecialize(xs), @nospecialize(y), size_dict, @nospecialize(dy))
+        res = ntuple(i -> einsum_grad(getixs(eins), xs, getiy(eins), size_dict, dy, i), length(xs))
+        return res
+    end
+    return copy(cache.content), back_propagate(bprule, code, cache, ȳ, size_dict)
+end
+
+"""
+    cost_and_gradient(code, xs, ȳ)
+
+Compute the cost and the gradients w.r.t. the input tensors `xs`.
+
+### Arguments
+- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
+- `xs`: The input tensors.
+- `ȳ`: The message to back-propagate. If `nothing` (the default), it is initialized to an all-one tensor of the output shape.
+
+### Returns
+- `cost`: The cost of the contraction.
+- `grads`: The gradients w.r.t. the input tensors.
+"""
+function cost_and_gradient(code, xs, ȳ = nothing)
+    if ȳ === nothing
+        ȳ = init_gradient(code, xs)
+        @assert ndims(ȳ) == 0 "The output must be a scalar! Or you need to feed the gradient manually. Got: $(ndims(ȳ))!"
+    end
+    cost, tree = gradient_tree(code, xs, ȳ)
+    # extract the gradients on the leaves (i.e. the input tensors).
+    return cost, extract_leaves(code, tree)
+end
+
+# initialize `y̅` as an all-one tensor; note we always start from `L̅ := 1`.
+function init_gradient(code, xs)
+    size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
+    output_size = getindex.(Ref(size_dict), getiyv(code))
+    ȳ = get_output_array(xs, output_size)
+    return fill!(ȳ, one(eltype(ȳ)))
+end
+
+# since slicing is not supported, we forward it to `NestedEinsum`.
+extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache)
+
+# extract the gradients on the leaf nodes.
+function extract_leaves(code::NestedEinsum, cache::CacheTree)
+    res = Vector{Any}(undef, length(getixsv(code)))
+    return extract_leaves!(code, cache, res)
+end
+
+function extract_leaves!(code, cache, res)
+    if isleaf(code)
+        # extract the cached gradient on the leaf node
+        res[tensorindex(code)] = cache.content
+    else
+        # recurse deeper
+        for (subcode, sib) in zip(siblings(code), cache.siblings)
+            extract_leaves!(subcode, sib, res)
+        end
+    end
+    return res
+end
\ No newline at end of file
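As a usage sketch (not part of the patch): when the contraction output is not a scalar, `cost_and_gradient` cannot initialize `ȳ` itself, so the seed is supplied explicitly; with an all-one seed the returned gradients are those of the sum of the output entries. The shapes below are made up for illustration.

using OMEinsum
using OMEinsum: cost_and_gradient

A, B, C = randn(2, 3), randn(3, 4), randn(4, 5)
code = ein"(ij,jk),kl->il"        # nested code with a 2×5 (non-scalar) output
ȳ = ones(2, 5)                    # seed the back-propagation with ∂L/∂y = 1 for every entry
y, (gA, gB, gC) = cost_and_gradient(code, (A, B, C), ȳ)
# gA should match Zygote.gradient((a, b, c) -> sum(code(a, b, c)), A, B, C)[1]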
+"""
+    optein"ij,jk,kl -> ik"(A, B, C)
+
+String macro interface similar to [`@ein_str`](@ref), but with an optimized contraction order (all dimensions are assumed to have a uniform size).
+"""
+macro optein_str(s::AbstractString)
+    code = ein(s)
+    optimize_code(code, uniformsize(code, 20), TreeSA(; ntrials=1, niters=10)).eins
+end
+
 function ein(s::AbstractString)
     s = replace(replace(s, "\n" => ""), " "=>"")
     m = match(r"([\(\)a-z,α-ω]*)->([a-zα-ω]*)", s)
diff --git a/src/slicing.jl b/src/slicing.jl
index bc5f7d8..18beb57 100644
--- a/src/slicing.jl
+++ b/src/slicing.jl
@@ -116,7 +116,7 @@ function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} wher
     it = SliceIterator(se, size_dict)
     res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv))
     eins_sliced = drop_slicedim(se.eins, se.slicing)
-    for slicemap in it
+    for slicemap in it  # `slicemap` is a Dict mapping the sliced labels to the current slice index
         # NOTE: @debug will break Zygote
         # @debug "computing slice $k/$(length(it))"
         xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs))
diff --git a/test/bp.jl b/test/bp.jl
new file mode 100644
index 0000000..f82d614
--- /dev/null
+++ b/test/bp.jl
@@ -0,0 +1,17 @@
+using OMEinsum, Test, Zygote
+
+@testset "bp check" begin
+    A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
+    cost0 = ein"(ij, jk), ki->"(A, B, C)[]
+    zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
+    cost, mg = OMEinsum.cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
+    @test cost[] ≈ cost0
+    @test all(zg .≈ mg)
+
+    code = optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA())
+    cost0 = code(A, B, C)[]
+    zg = Zygote.gradient((a, b, c)->code(a, b, c)[], A, B, C)
+    cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C))
+    @test cost[] ≈ cost0
+    @test all(zg .≈ mg)
+end
\ No newline at end of file
diff --git a/test/interfaces.jl b/test/interfaces.jl
index 9556217..3309b96 100644
--- a/test/interfaces.jl
+++ b/test/interfaces.jl
@@ -14,3 +14,8 @@ using OMEinsum: get_size_dict
     ijk-> ikl" == ein"ijk,ijk->ikl"
 end
+
+@testset "optein" begin
+    code = optein"ij,jk,ki->"
+    @test code isa NestedEinsum
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 10a6aea..e731b50 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -55,6 +55,10 @@ end
     include("contractionorder.jl")
 end
 
+@testset "back propagation" begin
+    include("bp.jl")
+end
+
 @testset "docstring" begin
     Documenter.doctest(OMEinsum; manual=false)
 end
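A possible usage sketch for the new optein string macro (not part of the patch): the optimized code should be a `NestedEinsum` that evaluates to the same result as the plain `ein` code, only with a better contraction order.

using OMEinsum

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = optein"ij,jk,ki->"       # contraction order optimized by TreeSA at macro-expansion time
@assert code isa NestedEinsum
@assert code(A, B, C)[] ≈ ein"ij,jk,ki->"(A, B, C)[]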