Skip to content

Commit

Permalink
Type-generic rewrite (#88)
Browse files Browse the repository at this point in the history
* Type-generic rewrite

* More work

* Tests passing

* Reactivate workflows

* Fix docs

* Format

* AutoZero

* Fix broken backends

* Better docs table

* Reinsert preparation

* Tests passing with bangbang

* Ready for merge
  • Loading branch information
gdalle authored Mar 23, 2024
1 parent 0a29640 commit d0903c5
Show file tree
Hide file tree
Showing 76 changed files with 1,813 additions and 3,787 deletions.
5 changes: 1 addition & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
/docs/build/
/docs/src/index.md

/Manifest.toml
/docs/Manifest.toml
/test/Manifest.toml
/benchmark/Manifest.toml
Manifest.toml

*.csv
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Expand All @@ -40,6 +41,7 @@ DifferentiationInterfaceFastDifferentiationExt = [
"RuntimeGeneratedFunctions",
]
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = ["DiffResults", "ForwardDiff"]
DifferentiationInterfaceJETExt = ["JET"]
DifferentiationInterfacePolyesterForwardDiffExt = [
Expand All @@ -63,6 +65,7 @@ Enzyme = "0.11"
FastDifferentiation = "0.3"
FillArrays = "1"
FiniteDiff = "2.22"
FiniteDifferences = "0.12"
ForwardDiff = "0.10"
JET = "0.8"
LinearAlgebra = "1"
Expand Down
49 changes: 15 additions & 34 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,42 @@ An interface to various automatic differentiation backends in Julia.

This package provides a backend-agnostic syntax to differentiate functions of the following types:

- **Allocating**: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays
- **Mutating**: `f!(y, x) = nothing` where `y` is an abstract array and `x` can be a real number or an abstract array
- **allocating**: `f(x) = y`
- **mutating**: `f!(y, x) = nothing`

## Features

- First and second order operators
- In-place and out-of-place differentiation
- Preparation mechanism (e.g. to create a config or tape)
- Cross-backend testing and benchmarking utilities
- Thorough validation on standard inputs and outputs (scalars, vectors, matrices)

## Compatibility

We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):
We support most of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Object |
| :------------------------------------------------------------------------------ | :----------------------------------------------------------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Enzyme.Forward)` or `AutoEnzyme(Enzyme.Reverse)` |
| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` |
| [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl) | `AutoFiniteDifferences(fdm)` |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |
| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)` |
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |
| [Tracker.jl](https://github.com/FluxML/Tracker.jl) | `AutoTracker()` |
| [Tracker.jl](https://github.com/FluxML/Tracker.jl) | `AutoTracker()` |
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |

We also provide additional backends:
We also provide one additional backend:

| Backend | Object |
| :------------------------------------------------------------------------------- | :-------------------------- |
| [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) | `AutoFastDifferentiation()` |

## Example

Setup:

```jldoctest readme
julia> import ADTypes, ForwardDiff
Expand All @@ -48,38 +55,12 @@ julia> using DifferentiationInterface
julia> backend = ADTypes.AutoForwardDiff();
julia> f(x) = sum(abs2, x);
```

Out-of-place gradient:

```jldoctest readme
julia> value_and_gradient(backend, f, [1., 2., 3.])
(14.0, [2.0, 4.0, 6.0])
```

In-place gradient:

```jldoctest readme
julia> grad = zeros(3);
julia> value_and_gradient!(grad, backend, f, [1., 2., 3.])
julia> value_and_gradient(f, backend, [1., 2., 3.])
(14.0, [2.0, 4.0, 6.0])
julia> grad
3-element Vector{Float64}:
2.0
4.0
6.0
```

## Related packages

- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl.
- [AutoDiffOperators.jl](https://github.com/oschulz/AutoDiffOperators.jl) is an attempt to bridge ADTypes.jl with AbstractDifferentiation.jl.

## Roadmap

Goals for future releases:

- optimize performance for each backend
- define user-facing functions to test and benchmark backends against each other
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand Down
46 changes: 14 additions & 32 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,13 @@ using Diffractor: Diffractor
using Enzyme: Enzyme
using FastDifferentiation: FastDifferentiation
using FiniteDiff: FiniteDiff
using FiniteDifferences: FiniteDifferences
using ForwardDiff: ForwardDiff
using PolyesterForwardDiff: PolyesterForwardDiff
using ReverseDiff: ReverseDiff
using Tracker: Tracker
using Zygote: Zygote

ChainRulesCoreExt = get_extension(DI, :DifferentiationInterfaceChainRulesCoreExt)
DiffractorExt = get_extension(DI, :DifferentiationInterfaceDiffractorExt)
EnzymeExt = get_extension(DI, :DifferentiationInterfaceEnzymeExt)
FastDifferentiationExt = get_extension(DI, :DifferentiationInterfaceFastDifferentiationExt)
FiniteDiffExt = get_extension(DI, :DifferentiationInterfaceFiniteDiffExt)
ForwardDiffExt = get_extension(DI, :DifferentiationInterfaceForwardDiffExt)
PolyesterForwardDiffExt = get_extension(
DI, :DifferentiationInterfacePolyesterForwardDiffExt
)
ReverseDiffExt = get_extension(DI, :DifferentiationInterfaceReverseDiffExt)
TrackerExt = get_extension(DI, :DifferentiationInterfaceTrackerExt)
ZygoteExt = get_extension(DI, :DifferentiationInterfaceZygoteExt)

DocMeta.setdocmeta!(
DifferentiationInterface,
:DocTestSetup,
:(using DifferentiationInterface, ADTypes);
recursive=true,
)

open(joinpath(@__DIR__, "src", "index.md"), "w") do io
println(
io,
Expand All @@ -60,16 +41,17 @@ makedocs(;
ADTypes,
DifferentiationInterface,
DifferentiationInterface.DifferentiationTest,
ChainRulesCoreExt,
DiffractorExt,
EnzymeExt,
FastDifferentiationExt,
FiniteDiffExt,
ForwardDiffExt,
PolyesterForwardDiffExt,
ReverseDiffExt,
TrackerExt,
ZygoteExt,
get_extension(DI, :DifferentiationInterfaceChainRulesCoreExt),
get_extension(DI, :DifferentiationInterfaceDiffractorExt),
get_extension(DI, :DifferentiationInterfaceEnzymeExt),
get_extension(DI, :DifferentiationInterfaceFastDifferentiationExt),
get_extension(DI, :DifferentiationInterfaceFiniteDiffExt),
get_extension(DI, :DifferentiationInterfaceFiniteDifferencesExt),
get_extension(DI, :DifferentiationInterfaceForwardDiffExt),
get_extension(DI, :DifferentiationInterfacePolyesterForwardDiffExt),
get_extension(DI, :DifferentiationInterfaceReverseDiffExt),
get_extension(DI, :DifferentiationInterfaceTrackerExt),
get_extension(DI, :DifferentiationInterfaceZygoteExt),
],
authors="Guillaume Dalle, Adrian Hill",
sitename="DifferentiationInterface.jl",
Expand All @@ -80,9 +62,9 @@ makedocs(;
),
pages=[
"Home" => "index.md", #
"getting_started.md", #
"api.md", #
"overview.md", #
"backends.md", #
"api.md", #
"developer.md",
],
warnonly=:missing_docs, # missing docs for ADTypes.jl are normal
Expand Down
37 changes: 21 additions & 16 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,60 @@ CollapsedDocStrings = true
DifferentiationInterface
```

## Scalar to scalar
## Derivative

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["src/derivative.jl", "src/second_derivative.jl"]
Pages = ["src/derivative.jl"]
```

## Scalar to array
## Gradient

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["multiderivative.jl"]
Pages = ["gradient.jl"]
```

## Array to scalar
## Jacobian

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["gradient.jl", "hessian.jl", "hessian_vector_product.jl"]
Pages = ["jacobian.jl"]
```

## Array to array
## Second order

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["jacobian.jl"]
Pages = ["second_order.jl", "second_derivative.jl", "hessian.jl", "hvp.jl"]
```

## Lower-level operators
## Primitives

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["pushforward.jl", "pullback.jl"]
```

## Backend queries

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["backends.jl"]
```

## Preparation

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["prepare.jl"]
```

## Backend queries
## Testing & benchmarking

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["backends.jl"]
Modules = [DifferentiationTest]
Private = false
```

## Internals
Expand All @@ -65,12 +72,10 @@ This is not part of the public API.
```@autodocs
Modules = [DifferentiationInterface]
Public = false
Order = [:function, :type]
```

## Testing

This is not part of the public API.

```@autodocs
Modules = [DifferentiationTest]
Public = false
```
42 changes: 5 additions & 37 deletions docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ CollapsedDocStrings = true
using ADTypes, DifferentiationInterface
using DifferentiationInterface.DifferentiationTest: backend_string
import Markdown
import Chairmarks, DataFrames
import Enzyme, FastDifferentiation, FiniteDiff, ForwardDiff, PolyesterForwardDiff, ReverseDiff, Tracker, Zygote
import Enzyme, FastDifferentiation, FiniteDiff, FiniteDifferences, ForwardDiff, PolyesterForwardDiff, ReverseDiff, Tracker, Zygote
function all_backends()
return [
Expand All @@ -16,6 +15,7 @@ function all_backends()
AutoEnzyme(Enzyme.Reverse),
AutoFastDifferentiation(),
AutoFiniteDiff(),
AutoFiniteDifferences(FiniteDifferences.central_fdm(5, 1)),
AutoForwardDiff(),
AutoPolyesterForwardDiff(; chunksize=2),
AutoReverseDiff(),
Expand Down Expand Up @@ -47,6 +47,7 @@ AutoEnzyme
AutoForwardDiff
AutoForwardDiff()
AutoFiniteDiff
AutoFiniteDifferences
AutoPolyesterForwardDiff
AutoPolyesterForwardDiff()
AutoReverseDiff
Expand All @@ -58,7 +59,6 @@ We also provide a few of our own:

```@docs
AutoFastDifferentiation
SecondOrder
```

## Availability
Expand All @@ -74,7 +74,7 @@ end # hide
Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
```

## [Mutation support](@id backend_mutation_behavior)
## Mutation support

All backends are compatible with allocating functions `f(x) = y`.
Only some are compatible with mutating functions `f!(y, x) = nothing`.
Expand All @@ -89,39 +89,6 @@ end # hide
Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
```

## Second order

For second-order differentiation, you can either

- combine a pair `(backend_inner, backend_outer)` of inner and outer backends into a [`SecondOrder`](@ref) object
- use a single backend `backend`, which amounts to `backend_inner = backend_outer = backend`

In Hessian computations, the most efficient combination is often forward-over-reverse, i.e. `SecondOrder(reverse_backend, forward_backend)`.

!!! info
Many backend combinations will fail for second order.
Some because of our implementation, and some because the outer backend cannot differentiate through code generated by the inner backend.

You can use [`check_hessian`](@ref) to find working combinations, like we did below (Enzyme was skipped due to compilation time):

```@example backends
header = "| Inner \\ Outer |" # hide
subheader = "|---|" # hide
for bo in all_backends_without_enzyme() # hide
global header *= " $(backend_string(bo)) |" # hide
global subheader *= "---|" # hide
end # hide
rows = map(all_backends_without_enzyme()) do bi # hide
@info "Generating hessian row for $(backend_string(bi))" # hide
row = "| $(backend_string(bi)) |" # hide
for bo in all_backends_without_enzyme() # hide
row *= " $(check_hessian(SecondOrder(bi, bo)) ? '✓' : '✗') |" # hide
end # hide
row # hide
end # hide
Markdown.parse(join(vcat(header, subheader, rows...), "\n") * "\n") # hide
```

## Package extensions

```@meta
Expand All @@ -137,6 +104,7 @@ Modules = [
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceEnzymeExt),
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFastDifferentiationExt),
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFiniteDiffExt),
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFiniteDifferencesExt),
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceForwardDiffExt),
Base.get_extension(DifferentiationInterface, :DifferentiationInterfacePolyesterForwardDiffExt),
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceReverseDiffExt),
Expand Down
Loading

0 comments on commit d0903c5

Please sign in to comment.