From 0fba083620d21024b17cd7033fdbf53f6d8a1528 Mon Sep 17 00:00:00 2001 From: James Martens Date: Fri, 6 Oct 2023 09:55:59 -0700 Subject: [PATCH] - Adding norm_to_scale_identity_weight_per_block to multiply and update_cache methods of estimator which allows the identity_weight to be scaled differently for each block according to some kind of norm (or norm-like function) of the curvature for that block. - Fixing minor bug that would cause some curvature blocks to use an improperly scaled damping when multiplying with power=1 and use_cached=True, for classes that have non-trivial state_dependent_scale methods. - Adding whitespace to improve readability. PiperOrigin-RevId: 571364604 --- examples/optimizers.py | 97 ++++++++++++++++++- kfac_jax/_src/curvature_blocks.py | 126 +++++++++++++++++++++++-- kfac_jax/_src/curvature_estimator.py | 94 ++++++++++++++----- kfac_jax/_src/optimizer.py | 18 ++++ kfac_jax/_src/utils/math.py | 134 +++++++++++++++++++++++++-- 5 files changed, 421 insertions(+), 48 deletions(-) diff --git a/examples/optimizers.py b/examples/optimizers.py index 59bf75d..829ef31 100644 --- a/examples/optimizers.py +++ b/examples/optimizers.py @@ -73,6 +73,7 @@ def __init__( distributed_precon_apply: bool = True, num_samples: int = 1, should_vmap_samples: bool = False, + norm_to_scale_identity_weight_per_block: Optional[str] = None, ): """Initializes the curvature estimator and preconditioner. @@ -141,10 +142,14 @@ def __init__( '[fisher,ggn]_curvature_prop'``. (Default: 1) should_vmap_samples: Whether to use ``jax.vmap`` to compute samples when ``num_samples > 1``. (Default: False) + norm_to_scale_identity_weight_per_block: The name of a norm to use to + compute extra per-block scaling for the damping. See psd_matrix_norm() + in utils/math.py for the definition of these. (Default: None) """ self._l2_reg = l2_reg self._damping = damping self._damping_schedule = damping_schedule + if (self._damping_schedule is None) == (self._damping is None): raise ValueError( "Only one of `damping_schedule` or `damping` has to be specified." @@ -159,6 +164,10 @@ def __init__( self._use_cached_inverses = self._inverse_update_period != 1 self._use_exact_inverses = use_exact_inverses + self._norm_to_scale_identity_weight_per_block = ( + norm_to_scale_identity_weight_per_block + ) + # Curvature estimator self._estimator = kfac_jax.curvature_estimator.BlockDiagonalCurvature( func=value_func, @@ -180,6 +189,7 @@ def init( rng: PRNGKey, ) -> PreconditionState: """Initializes the preconditioner and returns the state.""" + return PreconditionState( count=jnp.array(0, dtype=jnp.int32), estimator_state=self.estimator.init( @@ -193,6 +203,7 @@ def init( @property def _exact_powers_to_cache(self) -> Optional[Union[int, Sequence[int]]]: + if self._use_exact_inverses and self._use_cached_inverses: return -1 else: @@ -200,6 +211,7 @@ def _exact_powers_to_cache(self) -> Optional[Union[int, Sequence[int]]]: @property def _approx_powers_to_cache(self) -> Optional[Union[int, Sequence[int]]]: + if not self._use_exact_inverses and self._use_cached_inverses: return -1 else: @@ -217,9 +229,12 @@ def pmap_axis_name(self): def get_identity_weight( self, state: PreconditionState ) -> Union[Array, float]: + damping = self._damping + if damping is None: damping = self._damping_schedule(state.count) + return damping + self._l2_reg def sync_estimator_state( @@ -227,36 +242,43 @@ def sync_estimator_state( state: PreconditionState, ) -> PreconditionState: """Syncs the estimator state.""" + return PreconditionState( count=state.count, estimator_state=self.estimator.sync( state.estimator_state, pmap_axis_name=self.pmap_axis_name), ) - def should_update_estimate_curvature( + def should_update_estimator_curvature( self, state: PreconditionState ) -> Union[Array, bool]: """Whether at the current step the preconditioner should update the curvature estimates.""" + if self._curvature_update_period == 1: return True + return state.count % self._curvature_update_period == 0 def should_sync_estimate_curvature( self, state: PreconditionState ) -> Union[Array, bool]: """Whether at the current step the preconditioner should synchronize (pmean) the curvature estimates.""" + # sync only before inverses are calculated (either for updating the # cache or for preconditioning). if not self._use_cached_inverses: return True + return self.should_update_inverse_cache(state) def should_update_inverse_cache( self, state: PreconditionState ) -> Union[Array, bool]: """Whether at the current step the preconditioner should update the inverse cache.""" + if not self._use_cached_inverses: return False + return state.count % self._inverse_update_period == 0 def maybe_update( @@ -266,6 +288,7 @@ def maybe_update( rng: PRNGKey, ) -> PreconditionState: """Updates the estimates if it is the right iteration.""" + # NOTE: This maybe update curvatures and inverses at an iteration. But # if curvatures should be accumulated for multiple iterations # before updating inverses (for micro-batching), call @@ -277,7 +300,9 @@ def maybe_update( rng=rng, sync=self.should_sync_estimate_curvature(state), ) + state = self.maybe_update_inverse_cache(state) + return PreconditionState(state.count, state.estimator_state) def _update_estimator_curvature( @@ -300,6 +325,7 @@ def _update_estimator_curvature( rng=rng, func_args=func_args, ) + return jax.lax.cond( sync, functools.partial(self.estimator.sync, @@ -322,7 +348,7 @@ def maybe_update_estimator_curvature( return self._maybe_update_estimator_state( state, - self.should_update_estimate_curvature(state), + self.should_update_estimator_curvature(state), self._update_estimator_curvature, func_args=func_args, rng=rng, @@ -336,6 +362,7 @@ def maybe_update_inverse_cache( state: PreconditionState, ) -> PreconditionState: """Updates the estimator state cache if it is the right iteration.""" + if state.count is None: raise ValueError( "PreconditionState is not initialized. Call" @@ -351,6 +378,7 @@ def maybe_update_inverse_cache( approx_powers=self._approx_powers_to_cache, eigenvalues=False, pmap_axis_name=self.pmap_axis_name, + norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block, ) def _maybe_update_estimator_state( @@ -361,12 +389,14 @@ def _maybe_update_estimator_state( **update_func_kwargs, ) -> PreconditionState: """Updates the estimator state if it should update.""" + estimator_state = lax.cond( should_update, functools.partial(update_func, **update_func_kwargs), lambda s: s, state.estimator_state, ) + return PreconditionState(state.count, estimator_state) def apply( @@ -375,6 +405,7 @@ def apply( state: PreconditionState, ) -> optax.Updates: """Preconditions (= multiplies the inverse curvature estimation matrix to) updates.""" + new_updates = self.estimator.multiply_inverse( state=state.estimator_state, parameter_structured_vector=updates, @@ -382,15 +413,22 @@ def apply( exact_power=self._use_exact_inverses, use_cached=self._use_cached_inverses, pmap_axis_name=self.pmap_axis_name, + norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block, ) + if self._norm_constraint is not None: + sq_norm_grads = kfac_jax.utils.inner_product(new_updates, updates) del updates + max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_grads) coeff = jnp.minimum(max_coefficient, 1) + new_updates = kfac_jax.utils.scalar_mul(new_updates, coeff) + else: del updates + return new_updates def multiply_curvature( @@ -412,9 +450,10 @@ def multiply_curvature( state=state.estimator_state, parameter_structured_vector=updates, identity_weight=self.get_identity_weight(state), - exact_power=self._use_exact_inverses, # this argument will not be used. - use_cached=self._use_cached_inverses, # this argument will not be used. + exact_power=self._use_exact_inverses, + use_cached=self._use_cached_inverses, pmap_axis_name=self.pmap_axis_name, + norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block, ) return updates @@ -506,10 +545,12 @@ def __init__( self._value_func_has_aux = value_func_has_aux self._value_func_has_state = value_func_has_state self._value_func_has_rng = value_func_has_rng + if not callable(learning_rate): self._learning_rate = lambda _: learning_rate else: self._learning_rate = learning_rate + # Wraps the optax optimizer (gradient transformation), so that it ignores # extra args (i.e. `precond_state` for preconditioner) if not needed. self._optax_optimizer = optax.with_extra_args_support( @@ -542,13 +583,16 @@ def __init__( ) if self._preconditioner is not None: + if not isinstance(self._preconditioner, Preconditioner): raise ValueError( "preconditioner must be a {}, but {} is given.".format( Preconditioner, type(self._preconditioner) ) ) + preconditioner: Preconditioner = self._preconditioner + def _init_preconditioner( params: Params, rng: PRNGKey, @@ -556,7 +600,9 @@ def _init_preconditioner( func_state: Optional[FuncState] = None, ) -> PreconditionState: """Maybe initializes the PreconditionState.""" + batch = self._batch_process_func(batch) + func_args = kfac_jax.optimizer.make_func_args( params, func_state, @@ -565,6 +611,7 @@ def _init_preconditioner( has_state=self._value_func_has_state, has_rng=self._value_func_has_rng, ) + return preconditioner.init(func_args, rng) self._pmap_init_preconditioner = jax.pmap( @@ -597,8 +644,8 @@ def _step( """A single step of optax.""" rng_func, rng_precon = jax.random.split(rng) - batch = self._batch_process_func(batch) + func_args = kfac_jax.optimizer.make_func_args( params, func_state, rng_func, batch, has_state=self._value_func_has_state, @@ -606,6 +653,7 @@ def _step( ) optax_state, precond_state = state.optax_state, state.precond_state + if self._preconditioner is not None: precond_state = self._preconditioner.maybe_update( precond_state, @@ -613,15 +661,19 @@ def _step( rng_precon, ) precond_state = self._preconditioner.increment_count(precond_state) + out, grads = self._value_and_grad_func(*func_args) + loss, new_func_state, stats = kfac_jax.optimizer.extract_func_outputs( out, has_aux=self._value_func_has_aux, has_state=self._value_func_has_state, ) + loss, stats, grads = kfac_jax.utils.pmean_if_pmap( # pytype: disable=wrong-keyword-args (loss, stats, grads), axis_name=self.pmap_axis_name ) + stats = stats or {} stats["loss"] = loss @@ -641,21 +693,27 @@ def _step( stats["batch_size"] = batch_size * jax.device_count() stats["data_seen"] = stats["step"] * stats["batch_size"] stats["learning_rate"] = self._learning_rate(global_step_int) + if self._include_norms_in_stats: stats["grad_norm"] = kfac_jax.utils.norm(grads) stats["update_norm"] = kfac_jax.utils.norm(updates) stats["param_norm"] = kfac_jax.utils.norm(params) stats["rel_grad_norm"] = stats["grad_norm"] / stats["param_norm"] stats["rel_update_norm"] = stats["update_norm"] / stats["param_norm"] + if self._include_per_param_norms_in_stats: stats.update(kfac_jax.utils.per_parameter_norm(grads, "grad_norm")) stats.update(kfac_jax.utils.per_parameter_norm(updates, "update_norm")) param_norms = kfac_jax.utils.per_parameter_norm(params, "param_norm") + for key in param_norms: + norm = param_norms[key] stats[key] = norm + grad_key = key.replace("param", "grad") stats["rel_" + grad_key] = stats[grad_key] / norm + upd_key = key.replace("param", "update") stats["rel_" + upd_key] = stats[upd_key] / norm @@ -709,18 +767,27 @@ def tf1_rmsprop( def tf1_scale_by_rms(decay_=0.9, epsilon_=1e-8): """Same as optax.scale_by_rms, but initializes second moment to one.""" + def init_fn(params): nu = jax.tree_util.tree_map(jnp.ones_like, params) # second moment return optax.ScaleByRmsState(nu=nu) + def _update_moment(updates, moments, decay, order): + return jax.tree_util.tree_map( lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) + def update_fn(updates, state, params=None): + del params + nu = _update_moment(updates, state.nu, decay_, 2) + updates = jax.tree_util.tree_map( lambda g, n: g / (jnp.sqrt(n + epsilon_)), updates, nu) + return updates, optax.ScaleByRmsState(nu=nu) + return optax.GradientTransformation(init_fn, update_fn) return optax.chain( @@ -735,26 +802,34 @@ def linear_interpolation( interpolation_points: Tuple[Tuple[float, float], ...] ) -> Array: """Performs linear interpolation between the interpolation points.""" + xs, ys = zip(*interpolation_points) masks = [x < ci for ci in xs[1:]] + min_iter = jnp.zeros_like(x) max_iter = jnp.zeros_like(x) max_val = jnp.zeros_like(x) min_val = jnp.zeros_like(x) p = jnp.ones_like(x) + for i in range(len(masks) - 1): pi = p * masks[i] + min_iter = pi * xs[i] + (1 - pi) * min_iter max_iter = pi * xs[i + 1] + (1 - pi) * max_iter max_val = pi * ys[i] + (1 - pi) * max_val min_val = pi * ys[i + 1] + (1 - pi) * min_val + p = p * (1 - masks[i]) + min_iter = p * xs[-2] + (1 - p) * min_iter max_iter = p * xs[-1] + (1 - p) * max_iter max_val = p * ys[-2] + (1 - p) * max_val min_val = p * ys[-1] + (1 - p) * min_val + diff = (min_val - max_val) progress = (x - min_iter) / (max_iter - min_iter - 1) + return max_val + diff * jnp.minimum(progress, 1.0) @@ -772,12 +847,16 @@ def imagenet_sgd_schedule( # Can be found in Section 5.1 of https://arxiv.org/pdf/1706.02677.pdf steps_per_epoch = dataset_size / train_total_batch_size current_epoch = global_step / steps_per_epoch + lr = (0.1 * train_total_batch_size) / 256 lr_linear_till = 5 + boundaries = jnp.array((30, 60, 80)) * steps_per_epoch values = jnp.array([1., 0.1, 0.01, 0.001]) * lr + index = jnp.sum(boundaries < global_step) lr = jnp.take(values, index) + return lr * jnp.minimum(1., current_epoch / lr_linear_till) @@ -795,6 +874,7 @@ def kfac_resnet50_schedule( **_: Any, ) -> Array: """Custom schedule for KFAC.""" + return jnp.power(10.0, linear_interpolation( x=global_step, interpolation_points=( @@ -1033,6 +1113,7 @@ def construct_schedule( **kwargs, ) -> Callable[[Numeric], Array]: """Constructs the actual schedule from its name and extra kwargs.""" + if name == "fixed": return functools.partial(fixed_schedule, **kwargs) elif name == "imagenet_sgd": @@ -1053,16 +1134,21 @@ def kfac_bn_registration_kwargs(bn_registration: str) -> Mapping[ str, Union[Tuple[str, ...], Mapping[str, Type[kfac_jax.CurvatureBlock]]] ]: """Constructs KFAC kwargs for the given batch-norm registration strategy.""" + if bn_registration == "generic": return dict(patterns_to_skip=("scale_and_shift", "scale_only")) + elif bn_registration == "full": + return dict( layer_tag_to_block_cls=dict( scale_and_shift_tag=kfac_jax.ScaleAndShiftFull, ) ) + elif bn_registration != "diag": raise ValueError(f"Unknown batch_norm_registration={bn_registration}.") + return {} @@ -1129,6 +1215,7 @@ def create_optimizer( **kwargs.pop("learning_rate_schedule") ) optax_ctor = lambda lr: (getattr(optax, name)(learning_rate=lr, **kwargs)) + return OptaxWrapper( value_and_grad_func=value_and_grad_func, value_func_has_aux=has_aux, diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py index 9e69727..2be716c 100644 --- a/kfac_jax/_src/curvature_blocks.py +++ b/kfac_jax/_src/curvature_blocks.py @@ -254,6 +254,12 @@ def scale(self, state: "CurvatureBlock.State", use_cache: bool) -> Numeric: Returns: A scalar value to be multiplied with any unscaled block representation. """ + + # TODO(jamesmartens,botev): This way of handling state dependent scale is + # a bit hacky and leads to complexity in other parts of the code that must + # be aware of how this part works. Should try to replace this with something + # better. + if use_cache: return self.fixed_scale() @@ -365,7 +371,9 @@ def multiply_matpower( Returns: A tuple of arrays, representing the result of the matrix-vector product. """ + scale = self.scale(state, use_cached) + result = self._multiply_matpower_unscaled( state=state, vector=vector, @@ -541,6 +549,19 @@ def to_dense_matrix(self, state: "CurvatureBlock.State") -> Array: def _to_dense_unscaled(self, state: "CurvatureBlock.State") -> Array: """A dense representation of the curvature, ignoring ``self.scale``.""" + def norm(self, state: "CurvatureBlock.State", norm_type: str) -> Array: + """Computes the norm of the curvature block, according to ``norm_type``.""" + + return self.scale(state, False) * self._norm_unscaled(state, norm_type) + + @abc.abstractmethod + def _norm_unscaled( + self, + state: "CurvatureBlock.State", + norm_type: str + ) -> Array: + """Like ``norm`` but with ``self.scale`` not included.""" + class ScaledIdentity(CurvatureBlock): """A block that assumes that the curvature is a scaled identity matrix.""" @@ -596,9 +617,13 @@ def _multiply_matpower_unscaled( use_cached: bool, ) -> Tuple[Array, ...]: - del exact_power, use_cached # Unused + del exact_power # Unused - identity_weight = identity_weight + 1.0 + # state_dependent_scale needs to be included because it won't be by the + # caller of this function (multiply_matpower) when use_cached=True + scale = self.state_dependent_scale(state) if use_cached else 1.0 + + identity_weight = identity_weight + scale if power == 1: return jax.tree_util.tree_map(lambda x: identity_weight * x, vector) @@ -644,6 +669,14 @@ def _to_dense_unscaled(self, state: CurvatureBlock.State) -> Array: del state # not used return jnp.eye(self.dim) + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Array: + + return utils.psd_matrix_norm(jnp.ones([self.dim]), norm_type=norm_type) + class Diagonal(CurvatureBlock, abc.ABC): """An abstract class for approximating only the diagonal of curvature.""" @@ -701,7 +734,12 @@ def _multiply_matpower_unscaled( use_cached: bool, ) -> Tuple[Array, ...]: - factors = tuple(f.value + identity_weight for f in state.diagonal_factors) + # state_dependent_scale needs to be included because it won't be by the + # caller of this function (multiply_matpower) when use_cached=True + scale = self.state_dependent_scale(state) if use_cached else 1.0 + + factors = tuple(scale * f.value + identity_weight + for f in state.diagonal_factors) assert len(factors) == len(vector) @@ -728,6 +766,7 @@ def _update_cache( approx_powers: Set[Scalar], eigenvalues: bool, ) -> "Diagonal.State": + return state.copy() def _to_dense_unscaled(self, state: "Diagonal.State") -> Array: @@ -739,6 +778,16 @@ def _to_dense_unscaled(self, state: "Diagonal.State") -> Array: # Construct diagonal matrix return jnp.diag(jnp.concatenate(factors, axis=0)) + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Array: + + return utils.product( + utils.psd_matrix_norm(f.value.flatten(), norm_type=norm_type) + for f in state.diagonal_factors) + class Full(CurvatureBlock, abc.ABC): """An abstract class for approximating the block matrix with a full matrix.""" @@ -776,6 +825,7 @@ def __init__( if eigen_decomposition_threshold is None: threshold = get_default_eigen_decomposition_threshold() self._eigen_decomposition_threshold = threshold + else: self._eigen_decomposition_threshold = eigen_decomposition_threshold @@ -788,10 +838,12 @@ def parameters_list_to_single_vector( """Converts values corresponding to parameters of the block to vector.""" if len(parameters_shaped_list) != self.number_of_parameters: + raise ValueError(f"Expected a list of {self.number_of_parameters} values," f" but got {len(parameters_shaped_list)} instead.") for array, shape in zip(parameters_shaped_list, self.parameters_shapes): + if array.shape != shape: raise ValueError(f"Expected a value of shape {shape}, but got " f"{array.shape} instead.") @@ -815,6 +867,7 @@ def single_vector_to_parameters_list( index = 0 for shape in self.parameters_shapes: + size = utils.product(shape) parameters_shaped_list.append(vector[index: index + size].reshape(shape)) index += size @@ -880,7 +933,17 @@ def _multiply_matpower_unscaled( vector = self.parameters_list_to_single_vector(vector) if power == 1: - result = jnp.matmul(state.matrix.value, vector) + identity_weight * vector + + result = jnp.matmul(state.matrix.value, vector) + + if use_cached: + # state_dependent_scale needs to be included here because it won't be by + # the caller of this function (multiply_matpower) when use_cached=True. + # This is not an issue for other powers because they bake in + # state_dependent_scale. + result *= self.state_dependent_scale(state) + + result += identity_weight * vector elif not use_cached: @@ -911,8 +974,10 @@ def _eigenvalues_unscaled( state: "Full.State", use_cached: bool, ) -> Array: + if not use_cached: return utils.safe_psd_eigh(state.matrix.value)[0] + else: return state.cache["eigenvalues"] @@ -957,6 +1022,7 @@ def _update_cache( return state def _to_dense_unscaled(self, state: "Full.State") -> Array: + # Permute the matrix according to the parameters canonical order return utils.block_permuted( state.matrix.value, @@ -964,6 +1030,14 @@ def _to_dense_unscaled(self, state: "Full.State") -> Array: block_order=self.parameters_canonical_order ) + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Array: + + return utils.psd_matrix_norm(state.matrix.value, norm_type=norm_type) + class KroneckerFactored(CurvatureBlock, abc.ABC): """An abstract class for approximating the block with a Kronecker product.""" @@ -1070,6 +1144,7 @@ def _init( approx_powers_to_cache: Set[Scalar], cache_eigenvalues: bool, ) -> "KroneckerFactored.State": + cache = {} factors = [] @@ -1085,10 +1160,12 @@ def _init( cache[f"{i}_factor_eigen_vectors"] = jnp.zeros((d, d), dtype=self.dtype) for power in approx_powers_to_cache: + if power != -1: raise NotImplementedError( f"Approximations for power {power} is not yet implemented." ) + if str(power) not in cache: cache[str(power)] = {} @@ -1122,6 +1199,7 @@ def _multiply_matpower_unscaled( exact_power: bool, use_cached: bool, ) -> Tuple[Array, ...]: + assert len(state.factors) == len(self.axis_groups) vector = self.parameter_shaped_list_to_grouped_array(vector) @@ -1130,8 +1208,14 @@ def _multiply_matpower_unscaled( factors = [f.value for f in state.factors] + # state_dependent_scale needs to be included here because it won't be by + # the caller of this function (multiply_matpower) when use_cached=True. + # This is not an issue for other powers because they bake in + # state_dependent_scale. + scale = self.state_dependent_scale(state) if use_cached else 1.0 + if exact_power: - result = utils.kronecker_product_axis_mul_v(factors, vector) + result = scale * utils.kronecker_product_axis_mul_v(factors, vector) result = result + identity_weight * vector else: @@ -1139,9 +1223,9 @@ def _multiply_matpower_unscaled( # norm in its computation, it might make sense to cache it. But we # currently don't do that. - result = utils.kronecker_product_axis_mul_v( - utils.pi_adjusted_kronecker_factors(*factors, - damping=identity_weight), + result = scale * utils.kronecker_product_axis_mul_v( + utils.pi_adjusted_kronecker_factors( + *factors, damping=identity_weight / scale), vector) elif exact_power: @@ -1174,6 +1258,7 @@ def _multiply_matpower_unscaled( ) if use_cached: + factors = [ state.cache[str(power)][f"{i}_factor"] for i in range(len(state.factors)) @@ -1215,6 +1300,7 @@ def _update_cache( # pytype: disable=signature-mismatch # numpy-scalars approx_powers: Numeric, eigenvalues: bool, ) -> "KroneckerFactored.State": + assert len(state.factors) == len(self.axis_groups) # Copy this first since we mutate it later in this function. @@ -1224,8 +1310,11 @@ def _update_cache( # pytype: disable=signature-mismatch # numpy-scalars factor_scale = jnp.power(scale, 1.0 / len(self.axis_groups)) if eigenvalues or exact_powers: + s_q = [utils.safe_psd_eigh(factor.value) for factor in state.factors] + s, q = zip(*s_q) + for i in range(len(state.factors)): state.cache[f"{i}_factor_eigenvalues"] = factor_scale * s[i] @@ -1233,24 +1322,36 @@ def _update_cache( # pytype: disable=signature-mismatch # numpy-scalars state.cache[f"{i}_factor_eigen_vectors"] = q[i] for power in approx_powers: + if power != -1: raise NotImplementedError( f"Approximations for power {power} is not yet implemented." ) cache = state.cache[str(power)] + # This computes the approximate inverse factors using the generalization # of the pi-adjusted inversion from the original KFAC paper. - inv_factors = utils.pi_adjusted_kronecker_inverse( *[factor.value for factor in state.factors], damping=identity_weight, ) + for i in range(len(state.factors)): cache[f"{i}_factor"] = inv_factors[i] / factor_scale return state + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Array: + + return utils.product( + utils.psd_matrix_norm(f.value, norm_type=norm_type) + for f in state.factors) + class TwoKroneckerFactored(KroneckerFactored): """A Kronecker factored block for layers with weights and an optional bias.""" @@ -1271,29 +1372,35 @@ def parameters_shaped_list_to_array( self, parameters_shaped_list: Sequence[Array], ) -> Array: + for p, s in zip(parameters_shaped_list, self.parameters_shapes): assert p.shape == s if self.has_bias: w, b = parameters_shaped_list return jnp.concatenate([w.reshape([-1, w.shape[-1]]), b[None]], axis=0) + else: # This correctly reshapes the parameters of both dense and conv2d blocks [w] = parameters_shaped_list return w.reshape([-1, w.shape[-1]]) def array_to_parameters_shaped_list(self, array: Array) -> Tuple[Array, ...]: + if self.has_bias: w, b = array[:-1], array[-1] return w.reshape(self.parameters_shapes[0]), b + else: return tuple([array.reshape(self.parameters_shapes[0])]) def _to_dense_unscaled(self, state: "KroneckerFactored.State") -> Array: + assert 0 < self.number_of_parameters <= 2 inputs_factor = state.factors[0].value if self.has_bias and self.parameters_canonical_order[0] != 0: + # Permute the matrix according to the parameters canonical order inputs_factor = utils.block_permuted( state.factors[0].value, @@ -1327,6 +1434,7 @@ def update_curvature_matrix_estimate( for factor, dw in zip(state.diagonal_factors, estimation_data["params_tangent"]): + factor.update(dw * dw / batch_size, ema_old, ema_new) return state diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py index d6b458f..6bbcc75 100644 --- a/kfac_jax/_src/curvature_estimator.py +++ b/kfac_jax/_src/curvature_estimator.py @@ -546,8 +546,8 @@ class CurvatureEstimator(Generic[StateType], utils.Finalizable): The cached values are only updated once you call the method :func:`~CurvatureEstimator.update_cache`. Multiple methods contain the keyword argument ``use_cached`` which specify whether you want to compute the - corresponding expression using the current curvature estimate or used a cached - version. + corresponding expression using the current curvature estimate or using a + cached version. Attributes: func: The model evaluation function. @@ -643,6 +643,7 @@ def multiply_matpower( exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], + norm_to_scale_identity_weight_per_block: Optional[str] = None, ) -> utils.Params: """Computes ``(CurvatureMatrix + identity_weight I)**power`` times ``vector``. @@ -665,6 +666,9 @@ def multiply_matpower( pmap_axis_name: The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion. + norm_to_scale_identity_weight_per_block: The name of a norm to use to + compute extra per-block scaling for identity_weight. See + psd_matrix_norm() in utils/math.py for the definition of these. Returns: A parameter structured vector containing the product. @@ -678,6 +682,7 @@ def multiply( exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], + norm_to_scale_identity_weight_per_block: Optional[str] = None, ) -> utils.Params: """Computes ``(CurvatureMatrix + identity_weight I)`` times ``vector``.""" @@ -688,7 +693,8 @@ def multiply( power=1, exact_power=exact_power, use_cached=use_cached, - pmap_axis_name=pmap_axis_name + pmap_axis_name=pmap_axis_name, + norm_to_scale_identity_weight_per_block=norm_to_scale_identity_weight_per_block, ) def multiply_inverse( @@ -699,6 +705,7 @@ def multiply_inverse( exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], + norm_to_scale_identity_weight_per_block: Optional[str] = None, ) -> utils.Params: """Computes ``(CurvatureMatrix + identity_weight I)^-1`` times ``vector``.""" @@ -709,7 +716,8 @@ def multiply_inverse( power=-1, exact_power=exact_power, use_cached=use_cached, - pmap_axis_name=pmap_axis_name + pmap_axis_name=pmap_axis_name, + norm_to_scale_identity_weight_per_block=norm_to_scale_identity_weight_per_block, ) @abc.abstractmethod @@ -896,6 +904,7 @@ def __init__( registration function. """ super().__init__(func, params_index, default_estimation_mode) + self._index_to_block_ctor = index_to_block_ctor or dict() self._layer_tag_to_block_ctor = layer_tag_to_block_ctor or dict() self._auto_register_tags = auto_register_tags @@ -906,6 +915,7 @@ def __init__( auto_register_tags=auto_register_tags, **auto_register_kwargs ) + # Initialized during finalization self._jaxpr: Optional[tracer.ProcessedJaxpr] = None self._blocks: Optional[Tuple[curvature_blocks.CurvatureBlock]] = None @@ -923,6 +933,7 @@ def _check_finalized(self): def _create_blocks(self): """Creates all the curvature blocks instances in ``self._blocks``.""" + assert self._jaxpr is not None blocks_list = [] @@ -1115,9 +1126,12 @@ def _sync_state( state: "BlockDiagonalCurvature.State", pmap_axis_name: Optional[str], ) -> "BlockDiagonalCurvature.State": + block_states = [] + for block, block_state in zip(self.blocks, state.blocks_states): block_states.append(block.sync(block_state.copy(), pmap_axis_name)) + return BlockDiagonalCurvature.State( synced=jnp.asarray(True), blocks_states=tuple(block_states), @@ -1129,6 +1143,7 @@ def sync( state: "BlockDiagonalCurvature.State", pmap_axis_name: Optional[str], ) -> "BlockDiagonalCurvature.State": + return jax.lax.cond( state.synced, lambda s: s, @@ -1146,6 +1161,7 @@ def multiply_matpower( exact_power: bool, use_cached: bool, pmap_axis_name: Optional[str], + norm_to_scale_identity_weight_per_block: Optional[str] = None, ) -> utils.Params: blocks_vectors = self.params_vector_to_blocks_vectors( @@ -1153,21 +1169,35 @@ def multiply_matpower( identity_weight = utils.to_tuple_or_repeat(identity_weight, self.num_blocks) + def make_thunk(block, block_state, block_vector, block_identity_weight): + + def thunk(): + + weight = block_identity_weight + + if (norm_to_scale_identity_weight_per_block is not None + and norm_to_scale_identity_weight_per_block != "none"): + + weight *= block.norm( + block_state, norm_to_scale_identity_weight_per_block) + + return block.multiply_matpower( + state=block_state, + vector=block_vector, + identity_weight=weight, + power=power, + exact_power=exact_power, + use_cached=use_cached, + ) + + return thunk + thunks = [] for block, block_state, block_vector, block_identity_weight in zip( self.blocks, state.blocks_states, blocks_vectors, identity_weight): thunks.append( - functools.partial( - block.multiply_matpower, - state=block_state, - vector=block_vector, - identity_weight=block_identity_weight, - power=power, - exact_power=exact_power, - use_cached=use_cached, - ) - ) + make_thunk(block, block_state, block_vector, block_identity_weight)) if self._distributed_multiplies and pmap_axis_name is not None: @@ -1412,23 +1442,39 @@ def update_cache( approx_powers: Optional[curvature_blocks.ScalarOrSequence], eigenvalues: bool, pmap_axis_name: Optional[str], + norm_to_scale_identity_weight_per_block: Optional[str] = None, ) -> "BlockDiagonalCurvature.State": + identity_weight = utils.to_tuple_or_repeat(identity_weight, self.num_blocks) + def make_thunk(block, block_state, block_identity_weight): + + def thunk(): + + weight = block_identity_weight + + if (norm_to_scale_identity_weight_per_block is not None + and norm_to_scale_identity_weight_per_block != "none"): + + weight *= block.norm( + block_state, norm_to_scale_identity_weight_per_block) + + return block.update_cache( + state=block_state, + identity_weight=block_identity_weight, + exact_powers=exact_powers, + approx_powers=approx_powers, + eigenvalues=eigenvalues, + ) + + return thunk + thunks = [] for block, block_state, block_identity_weight in zip(self.blocks, state.blocks_states, identity_weight): - thunks.append( - functools.partial( - block.update_cache, - state=block_state, - identity_weight=block_identity_weight, - exact_powers=exact_powers, - approx_powers=approx_powers, - eigenvalues=eigenvalues, - ) - ) + + thunks.append(make_thunk(block, block_state, block_identity_weight)) if self._distributed_cache_updates and pmap_axis_name is not None: diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index b17d2f3..132cea4 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -141,6 +141,7 @@ def __init__( distributed_inverses: bool = True, num_estimator_samples: int = 1, should_vmap_estimator_samples: bool = False, + norm_to_scale_identity_weight_per_block: Optional[str] = None, ): """Initializes the K-FAC optimizer with the provided settings. @@ -351,6 +352,11 @@ def __init__( '[fisher,ggn]_curvature_prop'``. (Default: 1) should_vmap_estimator_samples: Whether to use ``jax.vmap`` to compute samples when ``num_estimator_samples > 1``. (Default: False) + norm_to_scale_identity_weight_per_block: The name of a norm to use to + compute extra per-block scaling for the damping. See psd_matrix_norm() + in utils/math.py for the definition of these. Note that this will not + effect the exact quadratic model that is used as part of the "adaptive" + learning rate, momentum, and damping methods. (Default: None) """ super().__init__( @@ -445,6 +451,16 @@ def schedule_with_first_step_zero( self._use_cached_inverses = (self._inverse_update_period != 1) self._use_exact_inverses = use_exact_inverses + self._norm_to_scale_identity_weight_per_block = ( + norm_to_scale_identity_weight_per_block + ) + + if (norm_to_scale_identity_weight_per_block is not None + and norm_to_scale_identity_weight_per_block != "none"): + + assert (not use_adaptive_learning_rate and not use_adaptive_momentum + and not use_adaptive_damping) # not currently supported + # Curvature estimator self._estimator = curvature_estimator.BlockDiagonalCurvature( func=self._value_func, @@ -784,6 +800,7 @@ def _compute_preconditioned_gradient( exact_power=self._use_exact_inverses, use_cached=self._use_cached_inverses, pmap_axis_name=self.pmap_axis_name, + norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block, ) if self._norm_constraint is not None: @@ -1276,6 +1293,7 @@ def c_times_v(v): exact_power=True, use_cached=False, pmap_axis_name=self.pmap_axis_name, + norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block, ) c_vectors = [c_times_v(v_i) for v_i in vectors] diff --git a/kfac_jax/_src/utils/math.py b/kfac_jax/_src/utils/math.py index 2684084..1c49a65 100644 --- a/kfac_jax/_src/utils/math.py +++ b/kfac_jax/_src/utils/math.py @@ -342,18 +342,21 @@ def psd_inv_cholesky(matrix: Array) -> Array: def psd_matrix_norm( matrix: Array, - norm_type: str = "avg_trace", + norm_type: str = "avg_diag", method_2norm: str = "lobpcg", rng_key: Optional[PRNGKey] = None -) -> Array: +) -> Numeric: """Computes one of several different matrix norms for PSD matrices. + NOTE: not all the functions options provided here are actually norms, but most + are. + Args: matrix: a square matrix represented as a 2D array, a 1D vector giving the diagonal, or a 0D scalar (which gets interpreted as a 1x1 matrix). Must be positive semi-definite (PSD). norm_type: a string specifying the type of matrix norm. Can be "2_norm" for - the matrix 2-norm aka the spectral norm, "avg_trace" for the average of + the matrix 2-norm aka the spectral norm, "avg_diag" for the average of diagonal entries, "1_norm" for the matrix 1-norm, or "avg_fro" for the Frobenius norm divided by the square root of the number of rows. method_2norm: a string specifying the method used to compute 2-norms. Can @@ -386,6 +389,7 @@ def psd_matrix_norm( matrix, v, m=300, tol=1e-8)[0][0] elif method_2norm == "power_iteration": + return optax.power_iteration( matrix, num_iters=300, error_tolerance=1e-7)[1] @@ -395,7 +399,7 @@ def psd_matrix_norm( else: raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") - elif norm_type == "avg_trace": + elif norm_type == "avg_diag": if matrix.ndim == 0: return matrix @@ -409,7 +413,103 @@ def psd_matrix_norm( else: raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") - elif norm_type == "1_norm": + elif norm_type == "median_diag": + + if matrix.ndim == 0: + return matrix + + elif matrix.ndim == 1: + return jnp.median(matrix) + + elif matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]: + return jnp.median(jnp.diag(matrix)) + + else: + raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") + + elif norm_type == "trace": + + if matrix.ndim == 0: + return matrix + + elif matrix.ndim == 1: + return jnp.sum(matrix) + + elif matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]: + return jnp.trace(matrix) + + else: + raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") + + elif norm_type == "median_eig": + + if matrix.ndim == 0: + return matrix + + elif matrix.ndim == 1: + return jnp.median(matrix) + + elif matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]: + # call safe_psd_eigh instead? + s, _ = jnp.linalg.eigh(matrix) + return jnp.median(s) + + else: + raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") + + elif norm_type == "median_eig_approx": + + if matrix.ndim == 0: + return matrix + + elif matrix.ndim == 1: + return jnp.median(matrix) + + elif matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]: + + rng_key = jax.random.PRNGKey(123) + + v = jax.random.normal( + rng_key, shape=[matrix.shape[0], 64]) + + y = matrix @ v + + s_samp = jnp.sqrt(jnp.sum(v*y, axis=0)) / jnp.sqrt( + jnp.sum(v**2, axis=0)) + + return jnp.median(s_samp) + + elif norm_type == "one_over_dim": # this isn't a norm + + if matrix.ndim == 0: + return 1.0 + + elif matrix.ndim == 1: + return 1.0 / matrix.shape[0] + + elif matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]: + return 1.0 / matrix.shape[0] + + else: + raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") + + elif norm_type == "5th_eig": + + if matrix.ndim == 0: + return matrix + + elif matrix.ndim == 1: + return jnp.sort(matrix)[-5] + + elif matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]: + # call safe_psd_eigh instead? + s, _ = jnp.linalg.eigh(matrix) + return jnp.sort(s)[-5] + + else: + raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") + + elif norm_type == "1_norm": # equiv to inf norm for symmetric matrices if matrix.ndim == 0: return matrix @@ -437,6 +537,20 @@ def psd_matrix_norm( else: raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") + elif norm_type == "fro": + + if matrix.ndim == 0: + return matrix + + elif matrix.ndim == 1: + return jnp.linalg.norm(matrix) + + elif matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]: + return jnp.linalg.norm(matrix) + + else: + raise ValueError(f"Unsupported shape for factor array: {matrix.shape}") + else: raise ValueError(f"Unrecognized norm type: '{norm_type}'") @@ -475,15 +589,15 @@ def pi_adjusted_kronecker_factors( # scalar factors `c_i` into a single overall scaling coefficient and # distribute the damping to each single non-scalar factor `u_i` equally. - norm_type = "avg_trace" + norm_type = "avg_diag" norms = [psd_matrix_norm(f, norm_type=norm_type) for f in factors] # Compute the normalized factors `u_i`, such that Trace(u_i) / dim(u_i) = 1 us = [fi / ni for fi, ni in zip(factors, norms)] - # kron(arrays) = c * kron(us) - + # Compute the overall norm for the whole Kronecker product. We should have + # kron(arrays) == c * kron(us). c = jnp.prod(jnp.array(norms)) damping = damping.astype(c.dtype) # pytype: disable=attribute-error # numpy-scalars @@ -508,10 +622,10 @@ def regular_case() -> Tuple[Array, ...]: for u in us: - if u.size == 1: + if u.size == 1: # scalar case u_hat = jnp.ones_like(u) # damping not used in the scalar factors - elif u.ndim == 2: + elif u.ndim == 2: # matrix case u_hat = u + d_hat * jnp.eye(u.shape[0], dtype=u.dtype) else: # diagonal case