From 84ef869e31b25be04bd011afbb0335db4b5d0281 Mon Sep 17 00:00:00 2001 From: Joel Jennings Date: Sat, 16 Nov 2024 07:25:30 -0800 Subject: [PATCH] Add a check in pi adjusted kronecker factors for damping being zero PiperOrigin-RevId: 697172845 --- kfac_jax/_src/utils/math.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/kfac_jax/_src/utils/math.py b/kfac_jax/_src/utils/math.py index 3c43ddd..082d1b3 100644 --- a/kfac_jax/_src/utils/math.py +++ b/kfac_jax/_src/utils/math.py @@ -617,9 +617,6 @@ def pi_adjusted_kronecker_factors( norms = jnp.array([psd_matrix_norm(f, norm_type=norm_type) for f in factors]) - # Compute the normalized factors `u_i`, such that Trace(u_i) / dim(u_i) = 1 - us = [fi / ni for fi, ni in zip(factors, norms)] - k = len(factors) # TODO(jamesmartens,botev): consider making the use of special behavior for @@ -665,7 +662,9 @@ def regular_case() -> tuple[Array, ...]: u_hats = [] - for u in us: + # Compute the normalized factors `u_i`, such that Trace(u_i) / dim(u_i) = 1 + for fi, ni in zip(factors, norms): + u = fi / ni if u.size == 1: # scalar case u_hat = jnp.ones_like(u) # damping not used in the scalar factors @@ -686,21 +685,14 @@ def zero_case() -> tuple[Array, ...]: # In the special case where for some reason one of the factors is zero, then # the we write each factor as `damping^(1/k) * I`. - c_k = jnp.power(damping, 1.0 / k) - - u_hats = [] - - for u in us: - - if u.ndim == 2: - u_hat = jnp.eye(u.shape[0], dtype=u.dtype) + d = lax.select(jnp.equal(damping, 0.0), 1e-8, damping) + c_k = jnp.power(d, 1.0 / k) - else: - u_hat = jnp.ones_like(u) - - u_hats.append(u_hat * c_k) - - return tuple(u_hats) + return tuple( + c_k * (jnp.eye(fi.shape[0], dtype=fi.dtype) if fi.ndim == 2 else + jnp.ones_like(fi)) + for fi in factors + ) if get_special_case_zero_inv(): return lax.cond(jnp.greater(jnp.min(norms), 0.0), regular_case, zero_case)