From 3fa5d27d8d7c57590cafcab3d5b2ce28de3402e3 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Wed, 21 Feb 2024 07:54:28 -0800 Subject: [PATCH] Hessian vector product and Gauss-Newton vector product utilities. PiperOrigin-RevId: 608993041 --- docs/api/utilities.rst | 25 +- optax/second_order/__init__.py | 11 +- optax/second_order/_base.py | 30 -- .../{_hessian.py => _deprecated.py} | 53 ++- optax/second_order/_fisher.py | 55 --- optax/second_order/_hessian_test.py | 88 ---- optax/second_order/_oracles.py | 443 ++++++++++++++++++ optax/second_order/_oracles_test.py | 306 ++++++++++++ 8 files changed, 813 insertions(+), 198 deletions(-) delete mode 100644 optax/second_order/_base.py rename optax/second_order/{_hessian.py => _deprecated.py} (65%) delete mode 100644 optax/second_order/_fisher.py delete mode 100644 optax/second_order/_hessian_test.py create mode 100644 optax/second_order/_oracles.py create mode 100644 optax/second_order/_oracles_test.py diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index 1dbc1bec9..6e5fb294a 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -49,21 +49,22 @@ Second Order Optimization .. currentmodule:: optax.second_order .. autosummary:: - fisher_diag - hessian_diag - hvp + hvp_call + make_gnvp_fn + make_hvp_fn -Fisher diagonal -~~~~~~~~~~~~~~~ -.. autofunction:: fisher_diag +Compute Hessian vector product (hvp) directly +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: hvp_call -Hessian diagonal -~~~~~~~~~~~~~~~~ -.. autofunction:: hessian_diag +Instantiate Gauss-Newton vector product (gnvp) function +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: make_gnvp_fn + +Instantiate Hessian vector product (hvp) function +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: make_hvp_fn -Hessian vector product -~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: hvp Tree diff --git a/optax/second_order/__init__.py b/optax/second_order/__init__.py index 1df9fae58..683ce83df 100644 --- a/optax/second_order/__init__.py +++ b/optax/second_order/__init__.py @@ -14,6 +14,11 @@ # ============================================================================== """The second order optimisation sub-package.""" -from optax.second_order._fisher import fisher_diag -from optax.second_order._hessian import hessian_diag -from optax.second_order._hessian import hvp +from optax.second_order._deprecated import fisher_diag +from optax.second_order._deprecated import hessian_diag +from optax.second_order._deprecated import hvp + +from optax.second_order._oracles import hvp_call +from optax.second_order._oracles import make_gnvp_fn +from optax.second_order._oracles import make_hvp_fn + diff --git a/optax/second_order/_base.py b/optax/second_order/_base.py deleted file mode 100644 index ac0e5301a..000000000 --- a/optax/second_order/_base.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Base types for the second order sub-package.""" - -import abc -from typing import Any, Protocol - -import jax - - -class LossFn(Protocol): - """A loss function to be optimized.""" - - @abc.abstractmethod - def __call__( - self, params: Any, inputs: jax.Array, targets: jax.Array - ) -> jax.Array: - ... diff --git a/optax/second_order/_hessian.py b/optax/second_order/_deprecated.py similarity index 65% rename from optax/second_order/_hessian.py rename to optax/second_order/_deprecated.py index 688fcb700..57cc8ab91 100644 --- a/optax/second_order/_hessian.py +++ b/optax/second_order/_deprecated.py @@ -12,28 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for computing diagonals of the Hessian wrt to a set of parameters. - -Computing the Hessian for neural networks is typically intractible due to the -quadratic memory requirements. Solving for the diagonal can be done cheaply, -with sub-quadratic memory requirements. +"""Deprecated utilities kept for backward compatibility. """ -from typing import Any +import abc +from typing import Any, Protocol import jax from jax import flatten_util import jax.numpy as jnp -from optax.second_order import _base - def _ravel(p: Any) -> jax.Array: return flatten_util.ravel_pytree(p)[0] +class LossFn(Protocol): + """A loss function to be optimized.""" + + @abc.abstractmethod + def __call__( + self, params: Any, inputs: jax.Array, targets: jax.Array + ) -> jax.Array: + ... + + def hvp( - loss: _base.LossFn, + loss: LossFn, v: jax.Array, params: Any, inputs: jax.Array, @@ -41,6 +46,8 @@ def hvp( ) -> jax.Array: """Performs an efficient vector-Hessian (of `loss`) product. + .. deprecated: 0.2 + Args: loss: the loss function. v: a vector of size `ravel(params)`. @@ -58,13 +65,15 @@ def hvp( def hessian_diag( - loss: _base.LossFn, + loss: LossFn, params: Any, inputs: jax.Array, targets: jax.Array, ) -> jax.Array: """Computes the diagonal hessian of `loss` at (`inputs`, `targets`). + .. deprecated: 0.2 + Args: loss: the loss function. params: model parameters. @@ -78,3 +87,27 @@ def hessian_diag( vs = jnp.eye(_ravel(params).size) comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets))) return jax.vmap(comp)(vs) + + +def fisher_diag( + negative_log_likelihood: LossFn, + params: Any, + inputs: jax.Array, + targets: jax.Array, +) -> jax.Array: + """Computes the diagonal of the (observed) Fisher information matrix. + + Args: + negative_log_likelihood: the negative log likelihood function with expected + signature `loss = fn(params, inputs, targets)`. + params: model parameters. + inputs: inputs at which `negative_log_likelihood` is evaluated. + targets: targets at which `negative_log_likelihood` is evaluated. + + Returns: + An Array corresponding to the product to the Hessian of + `negative_log_likelihood` evaluated at `(params, inputs, targets)`. + """ + return jnp.square( + _ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)) + ) diff --git a/optax/second_order/_fisher.py b/optax/second_order/_fisher.py deleted file mode 100644 index 11d7d2e05..000000000 --- a/optax/second_order/_fisher.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions for computing diagonals of the fisher information matrix. - -Computing the Fisher matrix for neural networks is typically intractible due to -the quadratic memory requirements. Solving for the diagonal can be done cheaply, -with sub-quadratic memory requirements. -""" - -from typing import Any - -import jax -from jax import flatten_util -import jax.numpy as jnp - -from optax.second_order import _base - - -def _ravel(p: Any) -> jax.Array: - return flatten_util.ravel_pytree(p)[0] - - -def fisher_diag( - negative_log_likelihood: _base.LossFn, - params: Any, - inputs: jax.Array, - targets: jax.Array, -) -> jax.Array: - """Computes the diagonal of the (observed) Fisher information matrix. - - Args: - negative_log_likelihood: the negative log likelihood function with - expected signature `loss = fn(params, inputs, targets)`. - params: model parameters. - inputs: inputs at which `negative_log_likelihood` is evaluated. - targets: targets at which `negative_log_likelihood` is evaluated. - - Returns: - An Array corresponding to the product to the Hessian of - `negative_log_likelihood` evaluated at `(params, inputs, targets)`. - """ - return jnp.square( - _ravel(jax.grad(negative_log_likelihood)(params, inputs, targets))) diff --git a/optax/second_order/_hessian_test.py b/optax/second_order/_hessian_test.py deleted file mode 100644 index da38cfa31..000000000 --- a/optax/second_order/_hessian_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `hessian.py`.""" - -import functools - -from absl.testing import absltest - -import chex -from flax import linen as nn -import jax -import jax.numpy as jnp -import numpy as np - -from optax.second_order import _hessian - - -NUM_CLASSES = 2 -NUM_SAMPLES = 3 -NUM_FEATURES = 4 - - -class HessianTest(chex.TestCase): - - def setUp(self): - super().setUp() - - self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES) - self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES) - - class MLP(nn.Module): - """A simple multilayer perceptron model for image classification.""" - - @nn.compact - def __call__(self, x): - # Flattens images in the batch. 
- x = x.reshape((x.shape[0], -1)) - x = nn.Dense(features=5)(x) - x = nn.relu(x) - x = nn.Dense(features=NUM_CLASSES)(x) - return x - - net = MLP() - self.parameters = net.init({'params': jax.random.PRNGKey(0)}, self.data)[ - 'params' - ] - - def loss(params, inputs, targets): - log_probs = net.apply({'params': params}, inputs) - return -jnp.mean(jax.nn.one_hot(targets, NUM_CLASSES) * log_probs) - - self.loss_fn = loss - - def jax_hessian_diag(loss_fun, params, inputs, targets): - """This is the 'ground-truth' obtained via the JAX library.""" - flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params) - - def flattened_loss(flat_params): - return loss_fun(unravel_fn(flat_params), inputs, targets) - - flat_hessian = jax.hessian(flattened_loss)(flat_params) - return jnp.diag(flat_hessian) - - self.hessian_diag = jax_hessian_diag( - self.loss_fn, self.parameters, self.data, self.labels) - - @chex.all_variants - def test_hessian_diag(self): - hessian_diag_fn = self.variant( - functools.partial(_hessian.hessian_diag, self.loss_fn)) - actual = hessian_diag_fn(self.parameters, self.data, self.labels) - np.testing.assert_array_almost_equal(self.hessian_diag, actual, 5) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax/second_order/_oracles.py b/optax/second_order/_oracles.py new file mode 100644 index 000000000..2deca4683 --- /dev/null +++ b/optax/second_order/_oracles.py @@ -0,0 +1,443 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to access second-order oracles efficiently.""" + +import functools +from typing import Any, Callable, NamedTuple, Union + +import chex +import jax +from optax._src import base +from optax._src import utils + +Scalar = Union[jax.Array, float] + + +def make_hvp_fn( + fn: Callable[..., Union[Scalar, tuple[Scalar, Any]]], + params: base.Params, + has_aux: bool = False, + **fn_kwargs, +) -> tuple[ + Union[tuple[Scalar, Any], Scalar], + base.Updates, + Callable[[base.Params], base.Params], +]: + r"""Instantiates Hessian vector product (hvp) function. + + In its simplest usage (for ``has_aux=False`` and ``**fn_kwargs`` empty), + this method returns the value of the function ``fn`` at the current + ``params``, the gradient of ``fn`` at ``params``, and a function ``hvp_fn`` + that gives acces to Hessian vector products. + In equation, this method returns + + .. math:: + + (f(w), \nabla f(w), v \rightarrow \nabla^2 f(w)v), + + where :math:`f` denotes the function ``fn``, and :math:`w` the parameters + ``params``. The output :math:`v \rightarrow \nabla^2 f(w)v` gives access to + Hessian vector products on any tangent vector :math:`v`. + + For ``has_aux=False`` and ``**fn_kwargs`` not empty, this method returns + :math:`(f(w; x), \nabla f(w; x), v \rightarrow \nabla^2 f(w; x)v)` + for :math:`x` some additional inputs fed into the method by keyword-only + arguments given in ``**fn_kwargs``. 
+ + For ``has_aux=True`` and ``**fn_kwargs`` not empty, this method returns + :math:`((f(w; x), a), \nabla f(w; x), v \rightarrow \nabla^2 f(w; x)v)` + where :math:`a` is some auxiliary outputs returned by the function. + + Examples: + >>> import jax.numpy as jnp + >>> from optax import second_order + >>> # Example without auxiliary output + >>> def fn(params, x, y): + ... logits = x.dot(params) + ... return 0.5*jnp.sum((logits - y)**2) + >>> params = jnp.array([1., 2., 3.]) + >>> x = jnp.array([[1., 2., 3.], [4., 5., 6.]]) + >>> y = jnp.array([1., 2.]) + >>> value, grad, hvp_fn = second_order.make_hvp_fn(fn, params, x=x, y=y) + >>> print(value, grad) + 534.5 [133. 176. 219.] + >>> tangents = jnp.array([1., 2., 3.]) + >>> hvp = hvp_fn(tangents) + >>> print(hvp) + [142. 188. 234.] + >>> # Example with auxiliary outputs + >>> def fn_with_aux(params, x, y): + ... logits = x.dot(params) + ... return 0.5*jnp.sum((logits - y)**2), logits + >>> (value, aux), grad, hvp_fn = second_order.make_hvp_fn( + ... fn_with_aux, params, x=x, y=y, has_aux=True + ... ) + >>> print(aux) + [14. 32.] + + .. note:: + This function is akin to :func:`jax.vjp` in the sense that it instantiates + the hvp function rather than directly computing the hvp value. As a vjp, it + stores in memory some intermediate computations to give access to the hvp + function. + When the hvp needs to be accessed multiple times, this function can reuse + the information stored memory so that the function is not reevaluated each + time. + + .. seealso:: :func:`optax.second_order.hvp_call` + + + Args: + fn: function to compute Hessian vector products from. (must be twice + differentiable in JAX) Must return either a scalar (if ``has_aux = + False``) or a pair of (scalar, aux) with aux some auxiliary output (if + ``has_aux = True``). + params: pytree of parameters at which we define the hvp. + has_aux: whether the function has auxiliary outputs or not. + **fn_kwargs: additional parameters for the function in keyword format. + + Returns: + ``(value, grad, hvp_fn)`` if ``has_aux = False``. + + ``((value, aux), grad, hvp_fn)`` if ``has_aux=True``. + + where + + - ``value`` is the value of ``fn`` at ``params`` and ``**fn_kwargs``. + + - ``grad`` is the gradient of ``fn`` at ``params`` and ``**fn_kwargs``. + + - ``hvp_fn`` is a function that takes some pytree of tangent direction + ``tangent`` of the same shape as ``params`` and returns the product + ``hvp_fn(tangent)`` of the Hessian of ``fn`` at ``params`` and + ``**fn_kwargs`` with the ``tangent`` direction. + + - ``aux`` is some auxiliary output returned by the function. + """ + + def grad_and_value_fn(x): + value_and_maybe_aux, grad = jax.value_and_grad(fn, has_aux=has_aux)( + x, **fn_kwargs + ) + return grad, value_and_maybe_aux + + grad, hvp_fn, value_and_maybe_aux = jax.linearize( + grad_and_value_fn, params, has_aux=True + ) + return value_and_maybe_aux, grad, hvp_fn + + +def hvp_call( + fn: Callable[..., Union[Scalar, tuple[Scalar, Any]]], + params: base.Params, + tangents: base.Params, + has_aux: bool = False, + **fn_kwargs, +) -> tuple[Union[Scalar, tuple[Scalar, Any]], base.Updates, base.Params]: + r"""Computes hessian vector product (hvp) directly by jvp on top of gradient. + + In its simplest usage (for ``has_aux=False`` and ``**fn_kwargs`` empty), + this method returns the value of the function ``fn`` at the current + ``params``, the gradient of ``fn`` at ``params``, and the Hessian of ``fn`` + at params multipled by the ``tangents`` direction. 
+ In equation, this method returns + + .. math:: + + (f(w), \nabla f(w), \nabla^2 f(w)v) + + where :math:`f` denotes the function ``fn``, :math:`w` the parameters + ``params``, and :math:`v` the tangent direction ``tangent``. + + For ``has_aux=False`` and ``**fn_kwargs`` not empty, this method returns + :math:`(f(w; x), \nabla f(w; x), \nabla^2 f(w; x)v)`, + for :math:`x` some additional inputs fed into the method by keyword-only + arguments given in ``**fn_kwargs``. + + For ``has_aux=True`` and ``**fn_kwargs`` not empty, this method returns + :math:`((f(w; x), a), \nabla f(w; x), \nabla^2 f(w; x)v)`, + where :math:`a` is some auxiliary outputs returned by the function. + + Examples: + >>> import jax.numpy as jnp + >>> from optax import second_order + >>> # Example without auxiliary output + >>> def fn(params, x, y): + ... logits = x.dot(params) + ... return 0.5*jnp.sum((logits - y)**2) + >>> params = jnp.array([1., 2., 3.]) + >>> tangents = jnp.array([1., 2., 3.]) + >>> x = jnp.array([[1., 2., 3.], [4., 5., 6.]]) + >>> y = jnp.array([1., 2.]) + >>> value, grad, hvp = second_order.hvp_call( + ... fn, params, tangents, x=x, y=y + ... ) + >>> print(value, grad) + 534.5 [133. 176. 219.] + >>> print(hvp) + [142. 188. 234.] + >>> # Example with auxiliary outputs + >>> def fn_with_aux(params, x, y): + ... logits = x.dot(params) + ... return 0.5*jnp.sum((logits - y)**2), logits + >>> (value, aux), grad, hvp = second_order.hvp_call( + ... fn_with_aux, params, tangents, x=x, y=y, has_aux=True + ... ) + >>> print(aux) + [14. 32.] + + .. note:: + This function is akin to :func:`jax.jvp` in the sense that it directly + returns the desired hvp value. There is no lingering memory cost as with + a :func:`jax.vjp` or in :func:`optax.second_order.make_hvp_fn`. + However, when the hvp needs to be accessed multiple times, this method + will reevaluate the function as many times. In other words, + :func:`optax.second_order.make_hvp_fn` may be preferred when the Hessian + must be accessed multiple times while :func:`optax.second_order.hvp_call` + may be preferred for a single call the hvp + + .. seealso:: :func:`optax.second_order.make_hvp_fn` + + Args: + fn: function to compute Hessian vector products from (must be twice + differentiable in JAX) Must return either a scalar (if ``has_aux = + False``) or a pair of (scalar, aux) with aux some auxiliary output (if + ``has_aux = True``). + params: pytree of parameters at which we define the hvp. + tangents: pytree of tangent direction along which we compute the hvp. of the + same shape as ``params``. + has_aux: whether the function has auxiliary outputs or not. + **fn_kwargs: additional parameters for the function in keyword format. + + Returns: + ``(value, grad, hvp)`` if ``has_aux = False``. + + ``((value, aux), grad, hvp)`` if ``has_aux=True``. + + where + + - ``value`` is the value of ``fn`` at ``params`` and ``**fn_kwargs``. + + - ``grad`` is the gradient of ``fn```` at ``params`` and ``**fn_kwargs``. + + - ``hvp`` is the product of the Hessian of ``fn`` at ``params`` and + ``**fn_kwargs`` with the ``tangent`` direction. + + - ``aux`` is some auxiliary output returned by the function. 
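+
+  As a rough illustration of the trade-off discussed in the note above, the
+  following sketch applies the Hessian to several tangent directions, reusing
+  ``fn``, ``params``, ``x`` and ``y`` from the examples (the names
+  ``directions``, ``hvps_direct`` and ``hvps_reused`` are purely
+  illustrative):
+
+    >>> directions = [jnp.ones(3), jnp.arange(3.)]
+    >>> # hvp_call re-evaluates the gradient of fn for every direction,
+    >>> hvps_direct = [
+    ...     second_order.hvp_call(fn, params, t, x=x, y=y)[2]
+    ...     for t in directions
+    ... ]
+    >>> # whereas make_hvp_fn linearizes fn once and reuses it for all of them.
+    >>> _, _, hvp_fn = second_order.make_hvp_fn(fn, params, x=x, y=y)
+    >>> hvps_reused = [hvp_fn(t) for t in directions]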
+ """ + + def grad_and_value_fn(x): + value_and_maybe_aux, grad = jax.value_and_grad(fn, has_aux=has_aux)( + x, **fn_kwargs + ) + return grad, value_and_maybe_aux + + grad, hvp, value_and_maybe_aux = jax.jvp( # pylint: disable=unbalanced-tuple-unpacking + grad_and_value_fn, (params,), (tangents,), has_aux=True + ) + return value_and_maybe_aux, grad, hvp + + +class InnerOuterAux(NamedTuple): + """Auxiliary outputs of two functions inner_fn, outer_fn.""" + + inner: Any + outer: Any + + +def make_gnvp_fn( + inner_fn: Callable[..., Union[chex.ArrayTree, tuple[chex.ArrayTree, Any]]], + outer_fn: Callable[..., Union[Scalar, tuple[Scalar, Any]]], + params: base.Params, + inner_fn_has_aux: bool = False, + outer_fn_has_aux: bool = False, + **fn_kwargs, +) -> tuple[ + Union[Scalar, tuple[Scalar, InnerOuterAux]], + base.Updates, + Callable[[base.Params], base.Params], +]: + r"""Instantiates Gauss-Newton vector product (gnvp). + + For a composition :math:`f\circ g`, where :math:`g` is ``inner_fn`` and + :math:`f` is ``outer_fn``, this method computes for + ``inner_fn_has_aux=False``, ``outer_fn_has_aux=False``, and empty + ``**fn_kwargs``, + + .. math:: + + (f(g(w)), \nabla (f\circ g)(w), + v \rightarrow J_g(w)^\top H_f(z) J_g(w) v )), + + where :math:`J_g(w)` is the Jacobian of :math:`g` at :math:`w`, + :math:`H_f(z)` is the Hessian of :math:`f` at :math:`z = g(w)`. + The output :math:`v \rightarrow J_g(w)^T H_f(z) J_g(w)` gives access to + Gauss-Newton vector products on any tangent vector :math:`v`. + + If ``**fn_kwargs`` is not empty, the method splits ``**fn_kwargs`` into + keyword-only arguments ``inner_fn_kwargs`` and ``outer_fn_kwargs`` + by examining the signatures of ``inner_fn`` and ``outer_fn``. + The method then returns (still for + ``inner_fn_has_aux=False``, ``outer_fn_has_aux=False``) + + .. math:: + + (f(g(w; x); y), \nabla (f(\cdot; y) \circ g(\cdot; x))(w), + v \rightarrow J_g(w; x)^\top H_f(z; y) J_g(w; x) v )), + + where :math:`x` and :math:`y` are ``inner_fn_kwargs`` and ``outer_fn_kwargs`` + respectively. + + If ``inner_fn_has_aux=True`` or ``outer_fn_has_aux=True``, this method returns + (presented for ``**fn_kwargs`` empty for simplicity) + + .. math:: + ((f(g(w)), (a_i, a_o)), \nabla (f\circ g)(w), + v \rightarrow J_g(w)^\top H_f(z) J_g(w) v )), + + where :math:`a_i` and :math:`a_o` are the auxiliary outputs returned by, + respectively ``inner_fn`` and ``outer_fn``. If e.g. ``inner_fn_has_aux=True`` + and ``outer_fn_has_aux=False``, the function still returns :math:`(a_i, a_o)` + but with :math:`a_o` ``None``. + + Examples: + >>> import jax.numpy as jnp + >>> from optax import second_order + >>> # Example without auxiliary output + >>> def net(params, x): + ... return x.dot(params) + >>> def loss(logits, y): + ... return 0.5*jnp.sum((logits - y)**2) + >>> params = jnp.array([1., 2., 3.]) + >>> x = jnp.array([[1., 2., 3.], [4., 5., 6.]]) + >>> y = jnp.array([1., 2.]) + >>> value, grad, gnvp_fn = second_order.make_gnvp_fn( + ... net, loss, params, x=x, y=y + ... ) + >>> print(value, grad) + 534.5 [133. 176. 219.] + >>> tangents = jnp.array([1., 2., 3.]) + >>> gnvp = gnvp_fn(tangents) + >>> print(gnvp) + [142. 188. 234.] + >>> # Note that the same result would be obtained using hvp_fn since net + >>> # (the inner_fn) is linear here. + >>> # Example with auxiliary outputs + >>> def loss_with_aux(logits, y): + ... return 0.5*jnp.sum((logits - y)**2), logits + >>> (value, aux), grad, gnvp_fn = second_order.make_gnvp_fn( + ... 
net, loss_with_aux, params, x=x, y=y, outer_fn_has_aux=True + ... ) + >>> print(aux) + InnerOuterAux(inner=None, outer=Array([14., 32.], dtype=float32)) + + Args: + inner_fn: inner function of the composition, :math:`g` in the formula above. + Must be differentiable in JAX. Can return a pytree (if ``inner_fn_has_aux + = False``) or a pair of (pytree, aux) with aux some auxiliary output (if + ``inner_fn_has_aux = True``). The output of inner_fn must match the first + argument of outer_fn. + outer_fn: outer function of the composition, :math:`f` in the formula above. + Must be twice differentiable in JAX. Must return either a scalar (if + ``outer_fn_has_aux = False``) or a pair of (scalar, aux) with aux some + auxiliary output (if ``outer_fn_has_aux = True``). + params: parameters of the composition, :math:`w` in the formula above. + inner_fn_has_aux: whether the inner function returns auxiliary outputs. + outer_fn_has_aux: whether the outer function returns auxiliary outputs. + **fn_kwargs: additional parameters for the composition in keyword format. If + ``**fn_kwargs`` is not empty, the method splits ``**fn_kwargs`` into + keyword-only arguments ``inner_fn_kwargs`` and ``outer_fn_kwargs`` by + examining the signatures of ``inner_fn`` and ``outer_fn``. + + Returns: + ``(value, grad, gnvp_fn)`` if ``(inner_has_aux or outer_has_aux) = False``. + + ``((value, aux), grad, gnvp_fn)`` + if ``(inner_has_aux or outer_has_aux) = True``. + + where + + - ``value`` is the value of the composition of ``inner_fn`` and ``outer_fn`` + at ``params`` and ``**fn_kwargs``. + + - ``grad`` is the gradient of the composition of ``inner_fn`` and + ``outer_fn`` at ``params`` and ``**fn_kwargs``. + + - ``gnvp_fn`` is a function that takes some pytree of tangent direction + ``tangent`` of the same shape as ``params`` and returns the product + ``gnvp_fn(tangent)`` of the Gauss-Newton matrix defined from the composition + of ``inner_fn`` and ``outer_fn`` evaluated at ``params`` and + ``**fn_kwargs`` with a ``tangent`` direction. + + - ``aux`` is a ``NameTuple`` with entries ``inner``, ``outer`` for + the auxiliary outputs of ``inner_fn`` and ``outer_fn`` respectively. + If e.g. ``inner_fn_has_aux=True`` and ``outer_fn_has_aux=False``, + then ``aux.outer`` is ``None``, or, equivalently, ``aux[1] = None``. + """ + + (inner_fn_kwargs, outer_fn_kwargs), remaining_kwargs = ( + utils._extract_fns_kwargs( # pylint: disable=protected-access + (inner_fn, outer_fn), fn_kwargs + ) + ) + if remaining_kwargs: + raise ValueError( + f'Some arguments {remaining_kwargs} are not passed to inner_fn nor ' + 'outer_fn.' 
+ ) + + inner_fn_aux = None + outer_fn_aux = None + + # Reduce inner fn to a single input function with auxiliary outputs + inner_fn_ = functools.partial(inner_fn, **inner_fn_kwargs) + + # Instantiates jacobian vector product (jvp) + if inner_fn_has_aux: + outputs, inner_jvp_fn, inner_fn_aux = jax.linearize( + inner_fn_, params, has_aux=True + ) + else: + outputs, inner_jvp_fn = jax.linearize(inner_fn_, params, has_aux=False) + + # Get jacobian transpose vector product (vjp) by linear transposition + inner_vjp_fn_ = jax.linear_transpose(inner_jvp_fn, params) + + # jax.linear_transpose returns tuples (like a vjp), we won't deal with + # multiple parameters so we alias it to return just what we want + # (not a tuple (grad,), but just grad) + inner_vjp_fn = lambda x: inner_vjp_fn_(x)[0] + + # Get hvp of outer function, with associated value, aux and gradient + value_and_maybe_aux, outer_grad, outer_hvp_fn = make_hvp_fn( + outer_fn, outputs, has_aux=outer_fn_has_aux, **outer_fn_kwargs + ) + if outer_fn_has_aux: + value, outer_fn_aux = value_and_maybe_aux + else: + value = value_and_maybe_aux + + # Compute overall gradient by backpropagating outer gradient through inner + # vjp + grad = inner_vjp_fn(outer_grad) + + # Creates gnvp fnction by adequate composition. + # We make gnvp a valid Pytree to enable jitting make_gnvp (see tests). + gnvp_fn = lambda tangents: inner_vjp_fn(outer_hvp_fn(inner_jvp_fn(tangents))) + + if inner_fn_has_aux or outer_fn_has_aux: + return (value, InnerOuterAux(inner_fn_aux, outer_fn_aux)), grad, gnvp_fn + else: + return value, grad, gnvp_fn diff --git a/optax/second_order/_oracles_test.py b/optax/second_order/_oracles_test.py new file mode 100644 index 000000000..ca3488de0 --- /dev/null +++ b/optax/second_order/_oracles_test.py @@ -0,0 +1,306 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for second order oracles from `optax.second_order._oracles.py`.""" + +import functools + +from absl.testing import absltest +from absl.testing import parameterized +import chex +import flax.linen as nn +import jax +from jax import flatten_util +import jax.numpy as jnp +import jax.random as jrd +import jax.tree_util as jtu +from optax.second_order import _oracles as oracles + + +def random_split_like_tree(rng_key, target_tree=None, treedef=None): + """Split keys to match structure of target tree or of tree_def.""" + if treedef is None: + treedef = jtu.tree_structure(target_tree) + keys = jrd.split(rng_key, treedef.num_leaves) + return jtu.tree_unflatten(treedef, keys) + + +def tree_random_normal_like(rng_key, target_tree): + """Create tree with normal random entries of the same shape as target tree.""" + keys_tree = random_split_like_tree(rng_key, target_tree) + return jtu.tree_map( + lambda l, k: jrd.normal(k, l.shape, l.dtype), + target_tree, + keys_tree, + ) + + +def get_random_sym_matrix(key, dim): + mat = jrd.normal(key, (dim, dim)) + mat = mat + mat.transpose() + return mat + + +class MLP(nn.Module): + num_outputs: int + hidden_sizes: list[int] + + @nn.compact + def __call__(self, x): + for num_hidden in self.hidden_sizes: + x = nn.Dense(num_hidden)(x) + x = nn.gelu(x) + return nn.Dense(self.num_outputs)(x) + + +def setup_mlp(input_dim, output_dim, hidden_sizes, key): + mlp = MLP(num_outputs=output_dim, hidden_sizes=hidden_sizes) + key, key_params, key_input, key_output = jrd.split(key, 4) + x = jrd.normal(key_input, (input_dim,)) + y = jrd.normal(key_output, (output_dim,)) + params = mlp.init(key_params, jnp.ones(input_dim)) + return mlp, params, x, y, key + + +class OraclesTest(chex.TestCase): + """Tests for second order oracles from `optax.second_order._oracles.py`.""" + + @chex.all_variants + def test_hvp_mlp_basic(self): + # Setup problem + """Test hvp on mlp.""" + key = jrd.PRNGKey(0) + input_dim, output_dim = 4, 2 + hidden_sizes = [4, 4] + mlp, params, x, _, key = setup_mlp(input_dim, output_dim, hidden_sizes, key) + tangents = tree_random_normal_like(key, params) + + def fn(params): + z = mlp.apply(params, x) + return jnp.sum(z**2) + + # Get reference quantities by flattening params + params_flat, unravel = flatten_util.ravel_pytree(params) + tangents_flat, _ = flatten_util.ravel_pytree(tangents) + + def fn_flat(params_flat): + params = unravel(params_flat) + return fn(params) + + value, grad = jax.value_and_grad(fn)(params) + hessian = jax.hessian(fn_flat)(params_flat) + hvp = hessian.dot(tangents_flat) + hvp = unravel(hvp) + + # Computations via make_hvp_fn + make_hvp_fn_ = functools.partial(oracles.make_hvp_fn, fn=fn) + make_hvp_fn = self.variant(make_hvp_fn_) + value1, grad1, hvp_fn = make_hvp_fn(params=params) + hvp1 = hvp_fn(tangents) + + # Computations via hvp_call + hvp_call_ = functools.partial(oracles.hvp_call, fn=fn) + hvp_call = self.variant(hvp_call_) + value2, grad2, hvp2 = hvp_call(params=params, tangents=tangents) + + # Check everything + # If on tpu or gpu matrix multiplications are done at half-precision + # so we should test at that precision. 
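+    # (jnp.finfo('bfloat16').eps is about 7.8e-3 and jnp.finfo('float32').eps
+    # about 1.2e-7; the factors of 10 and 100 below leave headroom for error
+    # accumulation across the MLP layers.)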
+    if jax.default_backend() in ['gpu', 'tpu']:
+      atol = 10*jnp.finfo('bfloat16').eps
+      rtol = 10*jnp.finfo('bfloat16').eps
+    else:
+      atol = 100*jnp.finfo('float32').eps
+      rtol = 100*jnp.finfo('float32').eps
+    chex.assert_trees_all_close(value, value1, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(value1, value2, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(grad, grad1, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(grad1, grad2, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(hvp, hvp1, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(hvp1, hvp2, atol=atol, rtol=rtol)
+
+  @chex.all_variants
+  def test_hvp_mlp_with_aux_and_fn_kwargs(self):
+    """Test hvp on an MLP with auxiliary outputs and function kwargs."""
+    key = jrd.PRNGKey(0)
+    input_dim, output_dim = 4, 2
+    hidden_sizes = [4, 4]
+    mlp, params, x, y, key = setup_mlp(input_dim, output_dim, hidden_sizes, key)
+    tangents = tree_random_normal_like(key, params)
+
+    def fn(params, x, y):
+      z = mlp.apply(params, x)
+      return jnp.sum((z - y) ** 2), z
+
+    # Get reference quantities by flattening params
+    params_flat, unravel = flatten_util.ravel_pytree(params)
+    tangents_flat, _ = flatten_util.ravel_pytree(tangents)
+
+    def fn_flat(params_flat, x, y):
+      params = unravel(params_flat)
+      return fn(params, x, y)
+
+    (value, aux), grad = jax.value_and_grad(fn, has_aux=True)(params, x, y)
+    hessian = jax.hessian(fn_flat, has_aux=True)(params_flat, x, y)[0]
+    hvp = hessian.dot(tangents_flat)
+    hvp = unravel(hvp)
+
+    # Computations via make_hvp_fn
+    make_hvp_fn_ = functools.partial(oracles.make_hvp_fn, fn=fn, has_aux=True)
+    make_hvp_fn = self.variant(make_hvp_fn_)
+    (value1, aux1), grad1, hvp_fn = make_hvp_fn(params=params, x=x, y=y)
+    hvp1 = hvp_fn(tangents)
+
+    # Computations via hvp_call
+    hvp_call_ = functools.partial(oracles.hvp_call, fn=fn, has_aux=True)
+    hvp_call = self.variant(hvp_call_)
+    (value2, aux2), grad2, hvp2 = hvp_call(
+        params=params, tangents=tangents, x=x, y=y
+    )
+
+    # Check everything
+    # If on tpu or gpu matrix multiplications are done at half-precision
+    # so we should test at that precision.
+    if jax.default_backend() in ['gpu', 'tpu']:
+      atol = 10*jnp.finfo('bfloat16').eps
+      rtol = 10*jnp.finfo('bfloat16').eps
+    else:
+      atol = 100*jnp.finfo('float32').eps
+      rtol = 100*jnp.finfo('float32').eps
+    chex.assert_trees_all_close(value, value1, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(value1, value2, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(aux, aux1, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(aux1, aux2, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(grad, grad1, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(grad1, grad2, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(hvp, hvp1, atol=atol, rtol=rtol)
+    chex.assert_trees_all_close(hvp1, hvp2, atol=atol, rtol=rtol)
+
+  @parameterized.product(
+      inner_has_aux=[True, False], outer_has_aux=[True, False]
+  )
+  def test_gnvp_mlp_with_outer_aux_and_fn_kwargs(
+      self, inner_has_aux, outer_has_aux
+  ):
+    # Setup problem
+    key = jrd.PRNGKey(0)
+    input_dim, output_dim = 4, 2
+    hidden_sizes = [4, 4]
+    mlp, params, x, y, key = setup_mlp(input_dim, output_dim, hidden_sizes, key)
+    params_flat, unravel = flatten_util.ravel_pytree(params)
+    tangents = tree_random_normal_like(key, params)
+    tangents_flat, _ = flatten_util.ravel_pytree(tangents)
+
+    # Define inner and outer functions with or without aux
+    if inner_has_aux:
+
+      def inner_fn(params, x):
+        return mlp.apply(params, x), x
+
+    else:
+
+      def inner_fn(params, x):
+        return mlp.apply(params, x)
+
+    if outer_has_aux:
+
+      def outer_fn(outputs, y):
+        return jnp.sum((outputs - y) ** 2), outputs
+
+    else:
+
+      def outer_fn(outputs, y):
+        return jnp.sum((outputs - y) ** 2)
+
+    # Define composition to get reference value and grad
+    def fn(params, x, y):
+      outputs_and_maybe_aux = inner_fn(params, x)
+      if inner_has_aux:
+        outputs, inner_aux = outputs_and_maybe_aux
+      else:
+        outputs = outputs_and_maybe_aux
+        inner_aux = None
+
+      value_and_maybe_aux = outer_fn(outputs, y)
+      if outer_has_aux:
+        value, outer_aux = value_and_maybe_aux
+      else:
+        value = value_and_maybe_aux
+        outer_aux = None
+      return value, (inner_aux, outer_aux)
+
+    # Get reference values (by flattening params for the gnvp)
+    (value, (inner_aux, outer_aux)), grad = jax.value_and_grad(
+        fn, has_aux=True
+    )(params, x, y)
+
+    def inner_fn_flat(params_flat, x_flat):
+      params = unravel(params_flat)
+      if inner_has_aux:
+        return inner_fn(params, x_flat)[0]
+      else:
+        return inner_fn(params, x_flat)
+
+    jac = jax.jacobian(inner_fn_flat)(params_flat, x)
+    outputs = inner_fn_flat(params_flat, x)
+    hessian_and_maybe_aux = jax.hessian(outer_fn, has_aux=outer_has_aux)(
+        outputs, y
+    )
+    if outer_has_aux:
+      hessian = hessian_and_maybe_aux[0]
+    else:
+      hessian = hessian_and_maybe_aux
+    gnvp_flat = jac.T.dot(hessian.dot(jac.dot(tangents_flat)))
+    gnvp = unravel(gnvp_flat)
+
+    # Computations via make_gnvp_fn
+    make_gnvp_fn = functools.partial(
+        oracles.make_gnvp_fn,
+        inner_fn=inner_fn,
+        outer_fn=outer_fn,
+        inner_fn_has_aux=inner_has_aux,
+        outer_fn_has_aux=outer_has_aux,
+    )
+    value_and_maybe_aux, grad1, gnvp_fn = make_gnvp_fn(
+        params=params,
+        x=x,
+        y=y,
+    )
+    gnvp1 = gnvp_fn(tangents)
+    if inner_has_aux or outer_has_aux:
+      value1, (inner_aux1, outer_aux1) = value_and_maybe_aux
+    else:
+      value1 = value_and_maybe_aux
+      inner_aux1 = None
+      outer_aux1 = None
+
+    # Check everything
+    # If on tpu or gpu matrix multiplications are done at half-precision
+    # so we should test at that precision.
+ if jax.default_backend() in ['gpu', 'tpu']: + atol = 10*jnp.finfo('bfloat16').eps + rtol = 10*jnp.finfo('bfloat16').eps + else: + atol = 100*jnp.finfo('float32').eps + rtol = 100*jnp.finfo('float32').eps + chex.assert_trees_all_close(value, value1, atol=atol, rtol=rtol) + if inner_has_aux or outer_has_aux: + chex.assert_trees_all_close(inner_aux, inner_aux1, atol=atol, rtol=rtol) + chex.assert_trees_all_close(outer_aux, outer_aux1, atol=atol, rtol=rtol) + chex.assert_trees_all_close(grad, grad1, atol=atol, rtol=rtol) + chex.assert_trees_all_close(gnvp, gnvp1, atol=atol, rtol=rtol) + + +if __name__ == '__main__': + absltest.main()
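
A minimal sketch of how these oracles compose with a conjugate-gradient solver
to form a damped Newton step (the helper `newton_step` and the `damping`
constant below are illustrative and not part of the optax API; the solve
relies on `jax.scipy.sparse.linalg.cg` accepting a matrix-vector function and
pytree operands):

    import jax
    import jax.numpy as jnp
    import optax


    def newton_step(fn, params, damping=1e-3, **fn_kwargs):
      # Build the hvp oracle once, then solve (H + damping * I) d = -g by CG,
      # reusing the stored linearization for every matrix-vector product.
      value, grad, hvp_fn = optax.second_order.make_hvp_fn(
          fn, params, **fn_kwargs
      )

      def matvec(v):
        hv = hvp_fn(v)
        return jax.tree_util.tree_map(lambda h, u: h + damping * u, hv, v)

      neg_grad = jax.tree_util.tree_map(jnp.negative, grad)
      direction, _ = jax.scipy.sparse.linalg.cg(matvec, neg_grad)
      return value, optax.apply_updates(params, direction)

Replacing `make_hvp_fn` with `make_gnvp_fn` (splitting `fn` into a network and
a loss) gives the corresponding Gauss-Newton step.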