Skip to content

Commit

Permalink
docs: add 3rd order AD example using Reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 4, 2025
1 parent 219eeba commit a71f40a
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pages = [
"manual/distributed_utils.md",
"manual/nested_autodiff.md",
"manual/compiling_lux_models.md",
"manual/exporting_to_jax.md",
"manual/nested_autodiff_reactant.md"
],
"API Reference" => [
"Lux" => [
Expand Down
4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ export default defineConfig({
text: "Exporting Lux Models to Jax",
link: "/manual/exporting_to_jax",
},
{
text: "Nested AutoDiff",
link: "/manual/nested_autodiff_reactant",
}
],
},
{
Expand Down
14 changes: 7 additions & 7 deletions docs/src/manual/nested_autodiff.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# [Nested Automatic Differentiation](@id nested_autodiff)

!!! note

This is a relatively new feature in Lux, so there might be some rough edges. If you
encounter any issues, please let us know by opening an issue on the
[GitHub repository](https://github.com/LuxDL/Lux.jl).

In this manual, we will explore how to use automatic differentiation (AD) inside your layers
or loss functions and have Lux automatically switch the AD backend with a faster one when
needed.

!!! tip
!!! tip "Reactant Support"

Reactant + Lux natively supports Nested AD (even higher dimensions). If you are using
Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant)
manual.

!!! tip "Disabling Nested AD Switching"

Don't wan't Lux to do this switching for you? You can disable it by setting the
`automatic_nested_ad_switching` Preference to `false`.
Expand Down
66 changes: 66 additions & 0 deletions docs/src/manual/nested_autodiff_reactant.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)

We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).

```@example nested_ad_reactant
using Reactant, Enzyme, Lux, Random, LinearAlgebra
const xdev = reactant_device()
const cdev = cpu_device()
# XXX: We need to be able to compile this with a for-loop else tracing time will scale
# proportionally to the number of elements in the input.
function ∇potential(potential, x)
dxs = onehot(x)
∇p = similar(x)
for i in eachindex(dxs)
dxᵢ = dxs[i]
res = only(Enzyme.autodiff(
Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ)
))
@allowscalar ∇p[i] = res[i]
end
return ∇p
end
function ∇²potential(potential, x)
dxs = onehot(x)
∇²p = similar(x)
for i in eachindex(dxs)
dxᵢ = dxs[i]
res = only(Enzyme.autodiff(
Enzyme.set_abi(Forward, Reactant.ReactantABI),
∇potential, Const(potential), Duplicated(x, dxᵢ)
))
@allowscalar ∇²p[i] = res[i]
end
return ∇²p
end
struct PotentialNet{P} <: Lux.AbstractLuxWrapperLayer{:potential}
potential::P
end
function (potential::PotentialNet)(x, ps, st)
pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
return ∇²potential(pnet, x), pnet.st
end
model = PotentialNet(Dense(5 => 5, gelu))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
x_ra = randn(Float32, 5, 3) |> xdev
model_compiled = @compile model(x_ra, ps, st)
model_compiled(x_ra, ps, st)
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))
function enzyme_gradient(model, x, ps, st)
return Enzyme.gradient(
Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
)
end
@jit enzyme_gradient(model, x_ra, ps, st)
```

0 comments on commit a71f40a

Please sign in to comment.