Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled greater range of preconditioner powers. Some math utilities added. #292

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion kfac_jax/_src/curvature_blocks/kronecker_factored.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
from typing import Any, Sequence

import jax
import jax.numpy as jnp
from kfac_jax._src import layers_and_loss_tags as tags
from kfac_jax._src import patches_second_moment as psm
Expand Down Expand Up @@ -280,7 +281,7 @@ def _multiply_matpower_unscaled(

else:

if power != -1 and power != -0.5:
if power not in [-1, -0.5, 0.5]:
raise NotImplementedError(
f"Approximations for power {power} is not yet implemented."
)
Expand All @@ -305,6 +306,19 @@ def _multiply_matpower_unscaled(
factors = utils.invert_psd_matrices(factors)
elif power == -0.5:
factors = utils.inverse_sqrt_psd_matrices(factors)
# TODO(timothycnguyen): Hacky psd square root. Will find a better way.
elif power == 0.5:
inverse_sqrt_factors = utils.inverse_sqrt_psd_matrices(factors)

def matmul(x, y):
if x.ndim == y.ndim == 2:
return jnp.dot(x, y)
assert x.ndim == y.ndim == 1
return x * y

factors = jax.tree_util.tree_map(
matmul, factors, inverse_sqrt_factors
)
else:
raise NotImplementedError()

Expand Down
2 changes: 2 additions & 0 deletions kfac_jax/_src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
scalar_div = math.scalar_div
weighted_sum_of_objects = math.weighted_sum_of_objects
sum_of_objects = math.sum_objects
pytree_size = math.pytree_size
inner_product = math.inner_product
symmetric_matrix_inner_products = math.symmetric_matrix_inner_products
matrix_of_inner_products = math.matrix_of_inner_products
Expand All @@ -131,6 +132,7 @@
invert_psd_matrices = math.invert_psd_matrices
inverse_sqrt_psd_matrices = math.inverse_sqrt_psd_matrices
stable_sqrt = math.stable_sqrt
cosine_similarity = math.cosine_similarity

del math

Expand Down
16 changes: 13 additions & 3 deletions kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@
"""K-FAC utilities for various mathematical operations."""
import functools
import string
from typing import Callable, Sequence, Iterable, TypeVar
from typing import Callable, Iterable, Sequence, TypeVar

import jax
from jax import lax
from jax.experimental.sparse import linalg as experimental_splinalg
import jax.numpy as jnp
from jax.scipy import linalg

from kfac_jax._src.utils import types

import numpy as np
import optax
import tree
Expand Down Expand Up @@ -156,6 +154,13 @@ def sum_objects(objects: Sequence[TArrayTree]) -> TArrayTree:
return weighted_sum_of_objects(objects, [1] * len(objects))


def pytree_size(pytree):
  """Returns the total number of elements across all leaves of `pytree`.

  The element count of every leaf is obtained with `jnp.size` and the counts
  are added together, starting from 0 (so a pytree with no leaves gives 0).

  Args:
    pytree: The pytree whose leaves should be measured.

  Returns:
    The sum of `jnp.size(leaf)` over all leaves of `pytree`.
  """
  leaf_sizes = (jnp.size(leaf) for leaf in jax.tree_util.tree_leaves(pytree))
  return sum(leaf_sizes, 0)


def _inner_product_float64(obj1: ArrayTree, obj2: ArrayTree) -> Array:
"""Computes inner product explicitly in float64 precision."""

Expand Down Expand Up @@ -1155,3 +1160,8 @@ def _stable_sqrt_fwd(
_sqrt_bound_derivative.defjvp(_stable_sqrt_fwd)

stable_sqrt = functools.partial(_sqrt_bound_derivative, max_gradient=1000.0)


def cosine_similarity(v1: ArrayTree, v2: ArrayTree) -> Array:
  """Computes the cosine similarity between two pytrees.

  The result is `inner_product(v1, v2)` divided by the product of `norm(v1)`
  and `norm(v2)`. No special handling is applied when either norm is zero.

  Args:
    v1: The first pytree.
    v2: The second pytree.

  Returns:
    The ratio of the inner product of `v1` and `v2` to the product of their
    norms.
  """
  dot = inner_product(v1, v2)
  norm_product = norm(v1) * norm(v2)
  return dot / norm_product
16 changes: 15 additions & 1 deletion kfac_jax/_src/utils/staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""K-FAC utilities for classes with staged methods."""

import functools
import numbers
from typing import Any, Callable, Sequence

import jax

from jax import lax
from kfac_jax._src.utils import misc
from kfac_jax._src.utils import parallel
from kfac_jax._src.utils import types


TArrayTree = types.TArrayTree


Expand Down Expand Up @@ -129,6 +131,18 @@ def replicate(self, obj: TArrayTree) -> TArrayTree:
else:
return obj

def pmean_if_pmap_wrapper(
    self,
    func: Callable[..., TArrayTree],
) -> Callable[..., TArrayTree]:
  """Wraps `func` so its output is `pmean`-ed when `multi_device` is set.

  Args:
    func: The function whose result should be averaged.

  Returns:
    `func` itself when `self.multi_device` is false. Otherwise, a new callable
    that forwards all positional and keyword arguments to `func` and applies
    `lax.pmean` to the result over the axis named `self.pmap_axis_name`.
  """
  if not self.multi_device:
    return func

  def pmean_of_func(*args, **kwargs):
    # The axis name is read from `self` at call time, not at wrap time.
    result = func(*args, **kwargs)
    return lax.pmean(result, self.pmap_axis_name)

  return pmean_of_func


def staged(
method: Callable[..., TArrayTree],
Expand Down
Loading