Skip to content

Commit

Permalink
Some math utilities added.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695286988
  • Loading branch information
timothyn617 authored and KfacJaxDev committed Nov 11, 2024
1 parent 8a610fc commit deeb486
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
3 changes: 3 additions & 0 deletions kfac_jax/_src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@
invert_psd_matrices = math.invert_psd_matrices
inverse_sqrt_psd_matrices = math.inverse_sqrt_psd_matrices
stable_sqrt = math.stable_sqrt
cosine_similarity = math.cosine_similarity
make_flattened_func = math.make_flattened_func
make_flattened_hvp_func = math.make_flattened_hvp_func

# The members of `math` have been aliased above; remove the module name
# itself so it is not left bound in this namespace.
del math

Expand Down
41 changes: 38 additions & 3 deletions kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@
"""K-FAC utilities for various mathematical operations."""
import functools
import string
from typing import Callable, Sequence, Iterable, TypeVar
from typing import Callable, Iterable, Sequence, TypeVar

import jax
from jax import lax
from jax.experimental.sparse import linalg as experimental_splinalg
import jax.numpy as jnp
from jax.scipy import linalg

from kfac_jax._src.utils import types

import numpy as np
import optax
import tree
Expand Down Expand Up @@ -1155,3 +1153,40 @@ def _stable_sqrt_fwd(
# Attach `_stable_sqrt_fwd` to `_sqrt_bound_derivative` through its `defjvp`
# hook (presumably a `jax.custom_jvp` function; its definition is not visible
# here -- TODO confirm).
_sqrt_bound_derivative.defjvp(_stable_sqrt_fwd)

# `stable_sqrt` is `_sqrt_bound_derivative` with `max_gradient` fixed to 1000.0.
stable_sqrt = functools.partial(_sqrt_bound_derivative, max_gradient=1000.0)


def cosine_similarity(v1, v2):
  """Returns `dot(v1, v2)` divided by the product of the norms of the inputs.

  Args:
    v1: The first array.
    v2: The second array.

  Returns:
    `jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2))`. No guard
    is applied to the denominator, so an input with zero norm results in a
    division by zero.
  """
  inner_product = jnp.dot(v1, v2)
  norm_product = jnp.linalg.norm(v1) * jnp.linalg.norm(v2)
  return inner_product / norm_product


def make_flattened_func(func, arg):
  """Returns a function that has flattened input and output.

  Args:
    func: The function to wrap. It is called with a pytree that has the same
      structure as `arg`, and its output may be any pytree that
      `ravel_pytree` can flatten.
    arg: An example input for `func`. Only used to build the function that
      maps a flat vector back to the pytree structure `func` expects.

  Returns:
    A function that takes a flat vector, unflattens it into the structure of
    `arg`, applies `func`, and returns the output of `func` flattened into a
    single vector.
  """
  # Import the submodule explicitly: `import jax` alone does not guarantee
  # that the `jax.flatten_util` attribute exists, so reading it off `jax`
  # only works if some other import has already loaded it.
  from jax import flatten_util  # pylint: disable=g-import-not-at-top

  _, unflatten_func = flatten_util.ravel_pytree(arg)

  def flattened_func(x):
    # `ravel_pytree` returns `(flat_vector, unravel_fn)`; only the vector is
    # part of this function's output.
    flat_output, _ = flatten_util.ravel_pytree(func(unflatten_func(x)))
    return flat_output

  return flattened_func


def make_flattened_hvp_func(grad_func, primals):
  """Returns a function that computes the Hessian-vector product (HVP).

  Returns an HVP function which takes in and returns a flattened vector. The
  product is obtained as the JVP of the flattened `grad_func` evaluated at the
  flattened `primals`, so it is a Hessian-vector product when `grad_func` is
  the gradient function of the objective of interest.

  Args:
    grad_func: A function whose JVP will be computed to obtain the HVP. It is
      called with a pytree that has the same structure as `primals`.
    primals: The primals for the HVP, i.e. the point at which the JVP of
      `grad_func` is taken.

  Returns:
    The flattened HVP function, which maps a flat tangent vector to the flat
    tangent output of the JVP.
  """
  # Import the submodule explicitly: `import jax` alone does not guarantee
  # that the `jax.flatten_util` attribute exists, so reading it off `jax`
  # only works if some other import has already loaded it.
  from jax import flatten_util  # pylint: disable=g-import-not-at-top

  flattened_primals, _ = flatten_util.ravel_pytree(primals)
  flattened_grad_func = make_flattened_func(grad_func, primals)

  def hvp_func(flattened_tangents):
    # `jax.jvp` returns `(primals_out, tangents_out)`; the tangent output is
    # the product being computed.
    _, tangents_out = jax.jvp(
        flattened_grad_func, (flattened_primals,), (flattened_tangents,)
    )
    return tangents_out

  return hvp_func

0 comments on commit deeb486

Please sign in to comment.