diff --git a/kfac_jax/_src/utils/__init__.py b/kfac_jax/_src/utils/__init__.py
index cd5436b..e843352 100644
--- a/kfac_jax/_src/utils/__init__.py
+++ b/kfac_jax/_src/utils/__init__.py
@@ -131,6 +131,9 @@
 invert_psd_matrices = math.invert_psd_matrices
 inverse_sqrt_psd_matrices = math.inverse_sqrt_psd_matrices
 stable_sqrt = math.stable_sqrt
+cosine_similarity = math.cosine_similarity
+make_flattened_func = math.make_flattened_func
+make_flattened_hvp_func = math.make_flattened_hvp_func
 
 del math
 
diff --git a/kfac_jax/_src/utils/math.py b/kfac_jax/_src/utils/math.py
index 3ad4275..0b9ee8d 100644
--- a/kfac_jax/_src/utils/math.py
+++ b/kfac_jax/_src/utils/math.py
@@ -14,16 +14,15 @@
 """K-FAC utilities for various mathematical operations."""
 import functools
 import string
-from typing import Callable, Sequence, Iterable, TypeVar
+from typing import Callable, Iterable, Sequence, TypeVar
 
 import jax
+from jax import flatten_util
 from jax import lax
 from jax.experimental.sparse import linalg as experimental_splinalg
 import jax.numpy as jnp
 from jax.scipy import linalg
-
 from kfac_jax._src.utils import types
-
 import numpy as np
 import optax
 import tree
@@ -1155,3 +1154,62 @@ def _stable_sqrt_fwd(
 _sqrt_bound_derivative.defjvp(_stable_sqrt_fwd)
 
 stable_sqrt = functools.partial(_sqrt_bound_derivative, max_gradient=1000.0)
+
+
+def cosine_similarity(v1, v2):
+  """Returns the cosine similarity of the vectors ``v1`` and ``v2``.
+
+  Computed as ``dot(v1, v2) / (norm(v1) * norm(v2))``. No guard is applied to
+  the denominator, so a zero-norm input gives a non-finite (NaN) result.
+
+  Args:
+    v1: The first vector (expected to be a flat, 1D array).
+    v2: The second vector, with the same size as ``v1``.
+
+  Returns:
+    A scalar array holding the cosine of the angle between ``v1`` and ``v2``.
+  """
+  return jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2))
+
+
+def make_flattened_func(func, arg):
+  """Returns a function that has flattened input and output.
+
+  Args:
+    func: A function from a pytree structured like ``arg`` to a pytree.
+    arg: An example input of ``func``; it is only used to build the function
+      that unflattens the flat input vector.
+
+  Returns:
+    A function that takes a flat vector, unflattens it to the structure of
+    ``arg``, applies ``func``, and returns the output raveled to a flat vector.
+  """
+  _, unflatten_func = flatten_util.ravel_pytree(arg)
+
+  def flattened_func(x):
+    return flatten_util.ravel_pytree(func(unflatten_func(x)))[0]
+
+  return flattened_func
+
+
+def make_flattened_hvp_func(grad_func, primals):
+  """Returns a function that computes the Hessian-vector product (HVP).
+
+  Returns an HVP function which takes in and returns a flattened vector.
+
+  Args:
+    grad_func: A function whose JVP will be computed to obtain the HVP.
+    primals: The primals for the HVP.
+
+  Returns:
+    The flattened HVP function.
+  """
+  flattened_primals, _ = flatten_util.ravel_pytree(primals)
+  flattened_grad_func = make_flattened_func(grad_func, primals)
+
+  def hvp_func(flattened_tangents):
+    return jax.jvp(
+        flattened_grad_func, (flattened_primals,), (flattened_tangents,)
+    )[1]
+
+  return hvp_func