Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better backend docs #76

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ Modules = [DifferentiationInterface]
Pages = ["prepare.jl"]
```

## Backend queries

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["backends.jl"]
```

## Internals

This is not part of the public API.
Expand Down
97 changes: 75 additions & 22 deletions docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,39 @@
CollapsedDocStrings = true
```

# Backends

```@docs
available
```@setup backends
using ADTypes, DifferentiationInterface
using DifferentiationInterface.DifferentiationTest: backend_string
import Markdown
import Chairmarks, DataFrames
import Enzyme, FastDifferentiation, FiniteDiff, ForwardDiff, PolyesterForwardDiff, ReverseDiff, Tracker, Zygote

function all_backends()
return [
AutoDiffractor(),
AutoEnzyme(Enzyme.Forward),
AutoEnzyme(Enzyme.Reverse),
AutoFastDifferentiation(),
AutoFiniteDiff(),
AutoForwardDiff(),
AutoPolyesterForwardDiff(; chunksize=2),
AutoReverseDiff(),
AutoTracker(),
AutoZygote(),
]
end

function all_backends_without_enzyme()
return filter(b -> !isa(b, AutoEnzyme), all_backends())
end
```

# Backends

## Types

Most backend choices are defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).


!!! warning
Only the backends listed here are supported by DifferentiationInterface.jl, even though ADTypes.jl defines more.

Expand All @@ -37,36 +59,67 @@ AutoFastDifferentiation
SecondOrder
```

## Availability

You can use [`available`](@ref) to verify whether a given backend is loaded, like we did below:

```@example backends
header = "| Backend | available |" # hide
subheader = "|---|---|" # hide
rows = map(all_backends()) do backend # hide
"| `$(backend_string(backend))` | $(available(backend) ? '✓' : '✗') |" # hide
end # hide
Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
```

## [Mutation support](@id backend_mutation_behavior)

All backends are compatible with allocating functions `f(x) = y`. Only some are compatible with mutating functions `f!(y, x) = nothing`.

| Backend | Mutating functions |
| :-------------------------------------- | ------------------ |
| `AutoChainRules(ruleconfig)` | ✗ |
| `AutoDiffractor()` | ✗ |
| `AutoEnzyme(Enzyme.Forward)` | ✓ |
| `AutoEnzyme(Enzyme.Reverse)` | ✓ |
| `AutoFiniteDiff()` | ✓ |
| `AutoForwardDiff()` | ✓ |
| `AutoPolyesterForwardDiff(; chunksize)` | ✓ |
| `AutoReverseDiff()` | ✓ |
| `AutoTracker()` | ✗ |
| `AutoZygote()` | ✗ |
All backends are compatible with allocating functions `f(x) = y`.
Only some are compatible with mutating functions `f!(y, x) = nothing`.
You can use [`supports_mutation`](@ref) to check that feature, like we did below:

```@example backends
header = "| Backend | mutation |" # hide
subheader = "|---|---|" # hide
rows = map(all_backends()) do backend # hide
"| `$(backend_string(backend))` | $(supports_mutation(backend) ? '✓' : '✗') |" # hide
end # hide
Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
```

## Second order

For second-order differentiation, you can either

- use a single backend
- combine a pair of backends into a [`SecondOrder`](@ref) object
- combine a pair `(backend_inner, backend_outer)` of inner and outer backends into a [`SecondOrder`](@ref) object
- use a single backend `backend`, which amounts to `backend_inner = backend_outer = backend`

In Hessian computations, the most efficient combination is often forward-over-reverse, i.e. `SecondOrder(reverse_backend, forward_backend)`.

!!! danger
Many backend combinations will fail for second order.
Some because of our implementation, and some because the outer backend cannot differentiate through code generated by the inner backend.
Be ready to experiment, or check our test suite to see which ones we have vetted.

You can use [`supports_hessian`](@ref) to find working combinations, like we did below (Enzyme is skipped here due to compilation overhead):

```@example backends
header = "| Inner \\ Outer |" # hide
subheader = "|---|" # hide
for bo in all_backends_without_enzyme() # hide
global header *= " $(backend_string(bo)) |" # hide
global subheader *= "---|" # hide
end # hide
rows = map(all_backends_without_enzyme()) do bi # hide
@info "Generating hessian row for $(backend_string(bi))" # hide
row = "| $(backend_string(bi)) |" # hide
for bo in all_backends_without_enzyme() # hide
row *= " $(supports_hessian(SecondOrder(bi, bo)) ? '✓' : '✗') |" # hide
end # hide
row # hide
end # hide
Markdown.parse(join(vcat(header, subheader, rows...), "\n") * "\n") # hide
```


## Package extensions

Expand Down
4 changes: 2 additions & 2 deletions src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ include("second_derivative.jl")
include("hessian_vector_product.jl")
include("hessian.jl")

include("available.jl")
include("backends.jl")

export AutoFastDifferentiation
export SecondOrder
Expand Down Expand Up @@ -83,7 +83,7 @@ export prepare_second_derivative
export prepare_hessian
export prepare_hessian_vector_product

export available
export available, supports_mutation, supports_hessian

# submodules
include("DifferentiationTest/DifferentiationTest.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/DifferentiationTest/pretty.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pretty(::AutoFastDifferentiation) = "FastDifferentiation"
pretty(::AutoFiniteDiff) = "FiniteDiff"
pretty(::AutoForwardDiff) = "ForwardDiff"
pretty(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff"
pretty(b::AutoReverseDiff) = "ReverseDiff($(b.compile))"
pretty(b::AutoReverseDiff) = "ReverseDiff$(b.compile ? "{compiled}" : "")"
pretty(::AutoTracker) = "Tracker"
pretty(::AutoZygote) = "Zygote"
pretty(b::AbstractADType) = string(b)
Expand Down
20 changes: 0 additions & 20 deletions src/available.jl

This file was deleted.

50 changes: 50 additions & 0 deletions src/backends.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
available(backend)

Check whether `backend` is available by trying a scalar-to-scalar derivative.
Might take a while due to compilation time.
"""
function available(backend::AbstractADType)
try
derivative(backend, identity, 1.0)
return true
catch e
if e isa MethodError
return false
else
throw(e)
end
end
end

available(backend::SecondOrder) = available(inner(backend)) && available(outer(backend))

"""
supports_mutation(backend)

Check whether `backend` supports differentiation of mutating functions by trying a jacobian.
Might take a while due to compilation time.
"""
function supports_mutation(backend::AbstractADType)
try
value_and_jacobian!([0.0], [0.0;;], backend, copyto!, [1.0])
return true
catch e
return false
end
end

"""
supports_hessian(backend)

Check whether `backend` supports second order differentiation by trying a hessian.
Might take a while due to compilation time.
"""
function supports_hessian(backend::AbstractADType)
try
hessian(backend, sum, [1.0])
return true
catch e
return false
end
end
1 change: 1 addition & 0 deletions test/chainrules_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ using JET: JET
using Test

@test available(AutoChainRules(ZygoteRuleConfig()))
@test !supports_mutation(AutoChainRules(ZygoteRuleConfig()))

test_operators(AutoChainRules(ZygoteRuleConfig()); second_order=false, type_stability=false);
1 change: 1 addition & 0 deletions test/diffractor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ using JET: JET
using Test

@test available(AutoDiffractor())
@test !supports_mutation(AutoDiffractor())

test_operators(AutoDiffractor(); second_order=false, type_stability=false);
1 change: 1 addition & 0 deletions test/enzyme_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using JET: JET
using Test

@test available(AutoEnzyme(Enzyme.Forward))
@test supports_mutation(AutoEnzyme(Enzyme.Forward))

test_operators(
AutoEnzyme(Enzyme.Forward); second_order=false, excluded=[:jacobian_allocating]
Expand Down
1 change: 1 addition & 0 deletions test/enzyme_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ using JET: JET
using Test

@test available(AutoEnzyme(Enzyme.Reverse))
@test supports_mutation(AutoEnzyme(Enzyme.Reverse))

test_operators(AutoEnzyme(Enzyme.Reverse); second_order=false);
1 change: 1 addition & 0 deletions test/finitediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using JET: JET
using Test

@test available(AutoFiniteDiff())
@test supports_mutation(AutoFiniteDiff())

test_operators(AutoFiniteDiff(); second_order=false, excluded=[:jacobian_allocating]);
test_operators(AutoFiniteDiff(), [:jacobian_allocating]; type_stability=false);
2 changes: 2 additions & 0 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ using JET: JET
using Test

@test available(AutoForwardDiff())
@test supports_mutation(AutoForwardDiff())
@test supports_hessian(AutoForwardDiff())

test_operators(AutoForwardDiff(; chunksize=2));
2 changes: 2 additions & 0 deletions test/polyesterforwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ using JET: JET
using Test

@test available(AutoPolyesterForwardDiff(; chunksize=2))
@test supports_mutation(AutoPolyesterForwardDiff(; chunksize=2))
@test supports_hessian(AutoPolyesterForwardDiff(; chunksize=2))

test_operators(
AutoPolyesterForwardDiff(; chunksize=2);
Expand Down
1 change: 1 addition & 0 deletions test/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using JET: JET
using Test

@test available(AutoReverseDiff())
@test supports_mutation(AutoReverseDiff())

test_operators(AutoReverseDiff(); second_order=false, type_stability=false);
test_operators(AutoReverseDiff(; compile=true); second_order=false, type_stability=false);
1 change: 1 addition & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using JET: JET
using Test

@test available(AutoTracker())
@test !supports_mutation(AutoTracker())

test_operators(
AutoTracker();
Expand Down
1 change: 1 addition & 0 deletions test/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ using JET: JET
using Test

@test available(AutoZygote())
@test !supports_mutation(AutoZygote())

test_operators(AutoZygote(); second_order=false, type_stability=false);
Loading