Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hessian vector product and Gauss-Newton vector product utilities. #817

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,22 @@ Second Order Optimization
.. currentmodule:: optax.second_order

.. autosummary::
fisher_diag
hessian_diag
hvp
hvp_call
make_gnvp_fn
make_hvp_fn

Fisher diagonal
~~~~~~~~~~~~~~~
.. autofunction:: fisher_diag
Compute Hessian vector product (hvp) directly
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: hvp_call

Hessian diagonal
~~~~~~~~~~~~~~~~
.. autofunction:: hessian_diag
Instantiate Gauss-Newton vector product (gnvp) function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: make_gnvp_fn

Instantiate Hessian vector product (hvp) function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: make_hvp_fn

Hessian vector product
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: hvp


Tree
Expand Down
11 changes: 8 additions & 3 deletions optax/second_order/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# ==============================================================================
"""The second order optimisation sub-package."""

from optax.second_order._fisher import fisher_diag
from optax.second_order._hessian import hessian_diag
from optax.second_order._hessian import hvp
from optax.second_order._deprecated import fisher_diag
from optax.second_order._deprecated import hessian_diag
from optax.second_order._deprecated import hvp

from optax.second_order._oracles import hvp_call
from optax.second_order._oracles import make_gnvp_fn
from optax.second_order._oracles import make_hvp_fn

30 changes: 0 additions & 30 deletions optax/second_order/_base.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing diagonals of the Hessian wrt a set of parameters.

Computing the Hessian for neural networks is typically intractable due to the
quadratic memory requirements. Solving for the diagonal can be done cheaply,
with sub-quadratic memory requirements.
"""Deprecated utilities kept for backward compatibility.
"""

from typing import Any
import abc
from typing import Any, Protocol

import jax
from jax import flatten_util
import jax.numpy as jnp

from optax.second_order import _base


def _ravel(p: Any) -> jax.Array:
return flatten_util.ravel_pytree(p)[0]


class LossFn(Protocol):
  """Structural type of a loss function to be optimized.

  Any callable taking `(params, inputs, targets)` and returning an array
  matches this protocol.
  """

  @abc.abstractmethod
  def __call__(
      self,
      params: Any,
      inputs: jax.Array,
      targets: jax.Array,
  ) -> jax.Array:
    """Evaluates the loss for `params` on the pair (`inputs`, `targets`)."""


def hvp(
loss: _base.LossFn,
loss: LossFn,
v: jax.Array,
params: Any,
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Performs an efficient vector-Hessian (of `loss`) product.

.. deprecated:: 0.2

Args:
loss: the loss function.
v: a vector of size `ravel(params)`.
Expand All @@ -58,13 +65,15 @@ def hvp(


def hessian_diag(
loss: _base.LossFn,
loss: LossFn,
params: Any,
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Computes the diagonal hessian of `loss` at (`inputs`, `targets`).

.. deprecated:: 0.2

Args:
loss: the loss function.
params: model parameters.
Expand All @@ -78,3 +87,27 @@ def hessian_diag(
vs = jnp.eye(_ravel(params).size)
comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets)))
return jax.vmap(comp)(vs)


def fisher_diag(
    negative_log_likelihood: LossFn,
    params: Any,
    inputs: jax.Array,
    targets: jax.Array,
) -> jax.Array:
  """Computes the diagonal of the (observed) Fisher information matrix.

  The diagonal is obtained as the element-wise square of the flattened
  gradient of `negative_log_likelihood` with respect to `params`.

  .. deprecated:: 0.2

  Args:
    negative_log_likelihood: the negative log likelihood function with expected
      signature `loss = fn(params, inputs, targets)`.
    params: model parameters.
    inputs: inputs at which `negative_log_likelihood` is evaluated.
    targets: targets at which `negative_log_likelihood` is evaluated.

  Returns:
    An Array holding the diagonal of the observed Fisher information matrix,
    i.e. the squared entries of the flattened gradient of
    `negative_log_likelihood` evaluated at `(params, inputs, targets)`.
  """
  # jax.grad differentiates with respect to the first argument (`params`).
  grads = jax.grad(negative_log_likelihood)(params, inputs, targets)
  return jnp.square(_ravel(grads))
55 changes: 0 additions & 55 deletions optax/second_order/_fisher.py

This file was deleted.

88 changes: 0 additions & 88 deletions optax/second_order/_hessian_test.py

This file was deleted.

Loading
Loading