Commit message: "save; update"
Showing 9 changed files with 247 additions and 2 deletions.
@@ -0,0 +1,29 @@
using OMEinsum
using OMEinsum: cost_and_gradient

# contract three matrices to a scalar, returning the cost and the gradients in one pass
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))

# evaluate the cost and the gradients on the leaves (i.e. the input tensors),
# writing the gradients into the preallocated container `res`
function gf(code, xs, res, ȳ = OMEinsum.init_gradient(code, xs))
    cost, tree = OMEinsum.gradient_tree(code, xs, ȳ)
    # extract the gradients on the leaves (i.e. the input tensors)
    return cost, OMEinsum.extract_leaves!(code, tree, res)
end

# differentiate through the gradient computation itself with Zygote,
# i.e. the gradient of `g ⋅ x` w.r.t. the inputs (a Hessian-vector-like check)
using Zygote
xA, xB, xC = randn(2, 3), randn(3, 4), randn(4, 2)
function gfunc(A, B, C)
    ȳ = fill(one(eltype(A)))  # 0-dimensional seed for the scalar output
    res = Zygote.Buffer(Any[nothing, nothing, nothing])  # mutation-friendly container for Zygote
    cost, (gA, gB, gC) = gf(ein"(ij, jk), ki->", (A, B, C), res, ȳ)
    @info "summing"
    return sum(gA .* xA) + sum(gB .* xB) + sum(gC .* xC)
end
Zygote.gradient(gfunc, A, B, C)

# cross-check the manual gradient against Zygote
zg = Zygote.gradient((a, b, c) -> ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
mg = gf(ein"(ij, jk), ki->", (A, B, C), Any[nothing, nothing, nothing])

# second-order check: the full Hessian via finite differences on the flattened inputs
using FiniteDiff
h = FiniteDiff.finite_difference_hessian(
    v -> ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[],
    [vec(A); vec(B); vec(C)])
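A first-order sanity check can be done the same way with `FiniteDiff.finite_difference_gradient`, the first-order counterpart of the Hessian call above; a minimal sketch under the same setup:

g_fd = FiniteDiff.finite_difference_gradient(
    v -> ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[],
    [vec(A); vec(B); vec(C)])
# g_fd should match the analytic gradients, e.g. vcat(vec.(mg[2])...), up to finite-difference error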
@@ -0,0 +1,164 @@
# `CacheTree` stores the intermediate `NestedEinsum` contraction results.
# It is a tree structure that is isomorphic to the contraction tree:
# `content` is the cached intermediate contraction result at this node,
# `siblings` are the child nodes of the current node.
struct CacheTree{T}
    content::AbstractArray{T}
    siblings::Vector{CacheTree{T}}
end
CacheTree(content::AbstractArray{T}, siblings) where T = CacheTree(content, CacheTree{T}[siblings...])
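# For illustration (hypothetical values, not part of the code below): contracting
# two matrices with `ein"ij,jk->ik"` produces a two-level tree
#     CacheTree(A * B, [CacheTree(A, []), CacheTree(B, [])])
# where the root caches the contraction result and the leaves cache the input tensors.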
"""
    cached_einsum(code, xs, size_dict)

Compute the einsum contraction and cache the intermediate contraction results.

### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `size_dict`: The size dictionary, which maps each label to the size of the corresponding dimension.

### Returns
- `CacheTree`: The cached tree storing the intermediate results.
"""
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
    # slicing is not supported yet
    if length(se.slicing) != 0
        @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Falling back to `NestedEinsum`."
    end
    return cached_einsum(se.eins, xs, size_dict)
end

# recursively contract and cache a tensor network
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
    if isleaf(code)
        # for a leaf node, cache the input tensor
        y = xs[tensorindex(code)]
        return CacheTree(y, CacheTree{eltype(y)}[])
    else
        # for a non-leaf node, contract the children and cache the result.
        caches = [cached_einsum(arg, xs, size_dict) for arg in siblings(code)]
        # `einsum` evaluates the einsum contraction:
        # its 1st argument is the contraction pattern,
        # its 2nd argument is a tuple of input tensors,
        # its 3rd argument is the size dictionary (label as the key, size as the value).
        y = einsum(rootcode(code), ntuple(i -> caches[i].content, length(caches)), size_dict)
        return CacheTree(y, caches)
    end
end
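# Usage sketch (illustrative values; `size_dict` as built in `gradient_tree` below):
#     cache = cached_einsum(ein"(ij, jk), ki->", (A, B, C), size_dict)
#     cache.content              # the 0-dimensional result array
#     cache.siblings[1].content  # the cached intermediate of the inner contraction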
"""
    back_propagate(f, code, cache, ȳ, size_dict)

Back-propagate the message `ȳ` through the cached tree `cache` and return a tree storing the intermediate messages.
The messages are typically gradients.

### Arguments
- `f`: The back-propagation rule. The signature is `f(eins, xs, y, size_dict, dy) -> dxs`, where
    - `eins`: The contraction code at the current node.
    - `xs`: The input tensors at the current node.
    - `y`: The output tensor at the current node.
    - `size_dict`: The size dictionary, which maps each label to the size of the corresponding dimension.
    - `dy`: The message on the output tensor (`y`) to back-propagate through the current node.
    - `dxs`: The messages on the input tensors (`xs`) resulting from the back-propagation.
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `cache`: The cached intermediate results, which can be generated by [`cached_einsum`](@ref).
- `ȳ`: The message to back-propagate.
- `size_dict`: The size dictionary, which maps each label to the size of the corresponding dimension.

### Returns
- `CacheTree`: The tree storing the intermediate messages.
"""
function back_propagate(f, se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
    if length(se.slicing) != 0
        @warn "Slicing is not supported for generating the gradient tree! Falling back to `NestedEinsum`."
    end
    return back_propagate(f, se.eins, cache, dy, size_dict)
end

function back_propagate(f, code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
    if isleaf(code)
        return CacheTree(dy, CacheTree{T}[])
    else
        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
        # `einsum_grad` is the back-propagation rule for the einsum function.
        # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`,
        # then the back-propagation pass is
        # ```
        # A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1)
        # B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2)
        # ...
        # ```
        # Let `L` be the loss; then `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`, etc.
        dxs = f(rootcode(code), xs, cache.content, size_dict, dy)
        return CacheTree(dy, back_propagate.(Ref(f), siblings(code), cache.siblings, dxs, Ref(size_dict)))
    end
end
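# Concretely (a standard identity, stated here for orientation): for a single
# matrix-product node `y = einsum(ein"ij,jk->ik", (A, B), size_dict)`, the rule yields
#     A̅ = y̅ * B'   (an `ein"ik,jk->ij"` contraction of (y̅, B))
#     B̅ = A' * y̅   (an `ein"ij,ik->jk"` contraction of (A, y̅))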
# the main function for generating the gradient tree
function gradient_tree(code::AbstractEinsum, xs, ȳ)
    # infer the sizes from the contraction code and the input tensors `xs`; returns a label-size dictionary
    size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
    # forward pass: compute and cache the intermediate results
    cache = cached_einsum(code, xs, size_dict)
    # back-propagate
    function bprule(eins, @nospecialize(xs), @nospecialize(y), size_dict, @nospecialize(dy))
        res = ntuple(i -> einsum_grad(getixs(eins), xs, getiy(eins), size_dict, dy, i), length(xs))
        return res
    end
    return copy(cache.content), back_propagate(bprule, code, cache, ȳ, size_dict)
end

"""
    cost_and_gradient(code, xs, ȳ)

Compute the cost and the gradients w.r.t. the input tensors `xs`.

### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `ȳ`: The message to back-propagate. The default is `1`.

### Returns
- `cost`: The cost of the contraction.
- `grads`: The gradients w.r.t. the input tensors.
"""
function cost_and_gradient(code, xs, ȳ = nothing)
    if ȳ === nothing
        ȳ = init_gradient(code, xs)
        @assert ndims(ȳ) == 0 "The output must be a scalar; otherwise feed the gradient `ȳ` manually. Got ndims = $(ndims(ȳ))!"
    end
    cost, tree = gradient_tree(code, xs, ȳ)
    # extract the gradients on the leaves (i.e. the input tensors)
    return cost, extract_leaves(code, tree)
end
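# Usage sketch (values are illustrative; mirrors the test added at the end of this commit):
#     A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
#     cost, (gA, gB, gC) = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
#     cost[]  # the scalar cost; gA holds ∂cost/∂A, and so on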
# initialize `y̅` as `1`. Note we always start from `L̅ := 1`.
function init_gradient(code, xs)
    size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
    output_size = getindex.(Ref(size_dict), getiyv(code))
    ȳ = get_output_array(xs, output_size)
    return fill!(ȳ, one(eltype(ȳ)))
end
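# For a scalar-valued contraction, the output label list `getiyv(code)` is empty,
# so `ȳ` is a 0-dimensional array holding a single `1`, matching the `L̅ := 1` convention above.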
# since slicing is not supported, we forward it to `NestedEinsum`
extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache)

# extract the gradients on the leaf nodes
function extract_leaves(code::NestedEinsum, cache::CacheTree)
    res = Vector{Any}(undef, length(getixsv(code)))
    return extract_leaves!(code, cache, res)
end

function extract_leaves!(code, cache, res)
    if isleaf(code)
        # extract the gradient at this leaf into the slot of the corresponding input tensor
        res[tensorindex(code)] = cache.content
    else
        # recurse deeper
        for (subcode, sib) in zip(siblings(code), cache.siblings)
            extract_leaves!(subcode, sib, res)
        end
    end
    return res
end
@@ -0,0 +1,17 @@
using OMEinsum, Test, Zygote

@testset "bp check" begin
    A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
    cost0 = ein"(ij, jk), ki->"(A, B, C)[]
    zg = Zygote.gradient((a, b, c) -> ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
    cost, mg = OMEinsum.cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
    @test cost[] ≈ cost0
    @test all(zg .≈ mg)

    # the same check on an optimized contraction order
    code = optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA())
    cost0 = code(A, B, C)[]
    zg = Zygote.gradient((a, b, c) -> code(a, b, c)[], A, B, C)
    cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C))
    @test cost[] ≈ cost0
    @test all(zg .≈ mg)
end
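For a contraction with a non-scalar output, `cost_and_gradient` requires the seed `ȳ` to be supplied explicitly (the scalar assertion only fires when `ȳ` is omitted). A minimal sketch, assuming a nested code with a matrix output; the code string and shapes are illustrative:

using OMEinsum
A, B, C = randn(2, 3), randn(3, 4), randn(4, 5)
# seed the back-propagation with ∂L/∂y = 1 for every output entry
ȳ = ones(2, 5)
y, (gA, gB, gC) = OMEinsum.cost_and_gradient(ein"(ij, jk), kl->il", (A, B, C), ȳ)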