Skip to content

Commit

Permalink
Add a check in pi adjusted kronecker factors for damping being zero
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697172845
  • Loading branch information
joeljennings authored and KfacJaxDev committed Nov 16, 2024
1 parent b14fb6a commit 84ef869
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,6 @@ def pi_adjusted_kronecker_factors(

norms = jnp.array([psd_matrix_norm(f, norm_type=norm_type) for f in factors])

# Compute the normalized factors `u_i`, such that Trace(u_i) / dim(u_i) = 1
us = [fi / ni for fi, ni in zip(factors, norms)]

k = len(factors)

# TODO(jamesmartens,botev): consider making the use of special behavior for
Expand Down Expand Up @@ -665,7 +662,9 @@ def regular_case() -> tuple[Array, ...]:

u_hats = []

for u in us:
# Compute the normalized factors `u_i`, such that Trace(u_i) / dim(u_i) = 1
for fi, ni in zip(factors, norms):
u = fi / ni

if u.size == 1: # scalar case
u_hat = jnp.ones_like(u) # damping not used in the scalar factors
Expand All @@ -686,21 +685,14 @@ def zero_case() -> tuple[Array, ...]:
# In the special case where for some reason one of the factors is zero, we
# write each factor as `damping^(1/k) * I`.

c_k = jnp.power(damping, 1.0 / k)

u_hats = []

for u in us:

if u.ndim == 2:
u_hat = jnp.eye(u.shape[0], dtype=u.dtype)
d = lax.select(jnp.equal(damping, 0.0), 1e-8, damping)
c_k = jnp.power(d, 1.0 / k)

else:
u_hat = jnp.ones_like(u)

u_hats.append(u_hat * c_k)

return tuple(u_hats)
return tuple(
c_k * (jnp.eye(fi.shape[0], dtype=fi.dtype) if fi.ndim == 2 else
jnp.ones_like(fi))
for fi in factors
)

if get_special_case_zero_inv():
return lax.cond(jnp.greater(jnp.min(norms), 0.0), regular_case, zero_case)
Expand Down

0 comments on commit 84ef869

Please sign in to comment.