Skip to content

Commit

Permalink
Hessian vector product and Gauss-Newton vector product utilities.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 608993041
  • Loading branch information
vroulet authored and OptaxDev committed Feb 22, 2024
1 parent bd3e5c8 commit 3fa5d27
Show file tree
Hide file tree
Showing 8 changed files with 813 additions and 198 deletions.
25 changes: 13 additions & 12 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,22 @@ Second Order Optimization
.. currentmodule:: optax.second_order

.. autosummary::
fisher_diag
hessian_diag
hvp
hvp_call
make_gnvp_fn
make_hvp_fn

Fisher diagonal
~~~~~~~~~~~~~~~
.. autofunction:: fisher_diag
Compute Hessian vector product (hvp) directly
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: hvp_call

Hessian diagonal
~~~~~~~~~~~~~~~~
.. autofunction:: hessian_diag
Instantiate Gauss-Newton vector product (gnvp) function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: make_gnvp_fn

Instantiate Hessian vector product (hvp) function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: make_hvp_fn

Hessian vector product
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: hvp


Tree
Expand Down
11 changes: 8 additions & 3 deletions optax/second_order/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# ==============================================================================
"""The second order optimisation sub-package."""

from optax.second_order._fisher import fisher_diag
from optax.second_order._hessian import hessian_diag
from optax.second_order._hessian import hvp
from optax.second_order._deprecated import fisher_diag
from optax.second_order._deprecated import hessian_diag
from optax.second_order._deprecated import hvp

from optax.second_order._oracles import hvp_call
from optax.second_order._oracles import make_gnvp_fn
from optax.second_order._oracles import make_hvp_fn

30 changes: 0 additions & 30 deletions optax/second_order/_base.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing diagonals of the Hessian wrt to a set of parameters.
Computing the Hessian for neural networks is typically intractible due to the
quadratic memory requirements. Solving for the diagonal can be done cheaply,
with sub-quadratic memory requirements.
"""Deprecated utilities kept for backward compatibility.
"""

from typing import Any
import abc
from typing import Any, Protocol

import jax
from jax import flatten_util
import jax.numpy as jnp

from optax.second_order import _base


def _ravel(p: Any) -> jax.Array:
return flatten_util.ravel_pytree(p)[0]


class LossFn(Protocol):
"""A loss function to be optimized."""

@abc.abstractmethod
def __call__(
self, params: Any, inputs: jax.Array, targets: jax.Array
) -> jax.Array:
...


def hvp(
loss: _base.LossFn,
loss: LossFn,
v: jax.Array,
params: Any,
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Performs an efficient vector-Hessian (of `loss`) product.
.. deprecated: 0.2
Args:
loss: the loss function.
v: a vector of size `ravel(params)`.
Expand All @@ -58,13 +65,15 @@ def hvp(


def hessian_diag(
loss: _base.LossFn,
loss: LossFn,
params: Any,
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Computes the diagonal hessian of `loss` at (`inputs`, `targets`).
.. deprecated: 0.2
Args:
loss: the loss function.
params: model parameters.
Expand All @@ -78,3 +87,27 @@ def hessian_diag(
vs = jnp.eye(_ravel(params).size)
comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets)))
return jax.vmap(comp)(vs)


def fisher_diag(
negative_log_likelihood: LossFn,
params: Any,
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Computes the diagonal of the (observed) Fisher information matrix.
Args:
negative_log_likelihood: the negative log likelihood function with expected
signature `loss = fn(params, inputs, targets)`.
params: model parameters.
inputs: inputs at which `negative_log_likelihood` is evaluated.
targets: targets at which `negative_log_likelihood` is evaluated.
Returns:
An Array corresponding to the product to the Hessian of
`negative_log_likelihood` evaluated at `(params, inputs, targets)`.
"""
return jnp.square(
_ravel(jax.grad(negative_log_likelihood)(params, inputs, targets))
)
55 changes: 0 additions & 55 deletions optax/second_order/_fisher.py

This file was deleted.

88 changes: 0 additions & 88 deletions optax/second_order/_hessian_test.py

This file was deleted.

Loading

0 comments on commit 3fa5d27

Please sign in to comment.