diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 000000000..a0821bc67 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,773 @@ +Common Optimizers +=================== + +.. currentmodule:: optax + +.. autosummary:: + + adabelief + adafactor + adagrad + adam + adamw + adamax + adamaxw + amsgrad + eve + fromage + lamb + lars + noisy_sgd + novograd + optimistic_gradient_descent + dpsgd + radam + rmsprop + sgd + sm3 + yogi + + +AdaBelief +~~~~~~~~~ + +.. autofunction:: adabelief + +AdaGrad +~~~~~~~ + +.. autofunction:: adagrad + +AdaFactor +~~~~~~~~~ + +.. autofunction:: adafactor + +Adam +~~~~ + +.. autofunction:: adam + +Adamax +~~~~ + +.. autofunction:: adamax + +AdamaxW +~~~~~ + +.. autofunction:: adamaxw + +AdamW +~~~~~ + +.. autofunction:: adamw + +AMSGrad +~~~~~ + +.. autofunction:: amsgrad + +Eve +~~~ + +.. autofunction:: eve + +Fromage +~~~~~~~ + +.. autofunction:: fromage + +Lamb +~~~~ + +.. autofunction:: lamb + +Lars +~~~~ + +.. autofunction:: lars + +SM3 +~~~ + +.. autofunction:: sm3 + + +Noisy SGD +~~~~~~~~~ + +.. autofunction:: noisy_sgd + + +Novograd +~~~~~~~~~ + +.. autofunction:: novograd + + +Optimistic GD +~~~~~~~~~~~~~ + +.. autofunction:: optimistic_gradient_descent + + +Differentially Private SGD +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: dpsgd + +RAdam +~~~~~ + +.. autofunction:: radam + +RMSProp +~~~~~~~ + +.. autofunction:: rmsprop + +SGD +~~~ + +.. autofunction:: sgd + +Yogi +~~~~ + +.. autofunction:: yogi + + +Optax Transformations +===================== + + +Gradient Transforms +------------------- + +.. currentmodule:: optax + +.. autosummary:: + + adaptive_grad_clip + add_decayed_weights + add_noise + AddDecayedWeightsState + additive_weight_decay + AdditiveWeightDecayState + AddNoiseState + apply_every + ApplyEvery + bias_correction + centralize + clip + clip_by_block_rms + clip_by_global_norm + ClipByGlobalNormState + ClipState + ema + EmaState + EmptyState + FactoredState + global_norm + GradientTransformation + identity + keep_params_nonnegative + NonNegativeParamsState + OptState + Params + scale + scale_by_adam + scale_by_adamax + scale_by_amsgrad + scale_by_belief + scale_by_factored_rms + scale_by_novograd + scale_by_optimistic_gradient + scale_by_param_block_norm + scale_by_param_block_rms + scale_by_radam + scale_by_rms + scale_by_rss + scale_by_schedule + scale_by_sm3 + scale_by_stddev + scale_by_trust_ratio + scale_by_yogi + ScaleByAdamState + ScaleByAmsgradState + ScaleByNovogradState + ScaleByRmsState + ScaleByRssState + ScaleByRStdDevState + ScaleByScheduleState + ScaleByTrustRatioState + ScaleBySM3State + ScaleState + stateless + stateless_with_tree_map + set_to_zero + trace + TraceState + TransformInitFn + TransformUpdateFn + update_infinity_moment + update_moment + update_moment_per_elem_norm + Updates + zero_nans + ZeroNansState + + +Optax Types +~~~~~~~~~~~~~~ + +.. autoclass:: GradientTransformation + :members: + +.. autoclass:: TransformInitFn + :members: + +.. autoclass:: TransformUpdateFn + :members: + +.. autoclass:: OptState + :members: + +.. autoclass:: Params + :members: + +.. autoclass:: Updates + :members: + + +Optax Transforms and States +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adaptive_grad_clip +.. autoclass:: AdaptiveGradClipState + :members: + +.. autofunction:: add_decayed_weights +.. autofunction:: add_noise +.. autoclass:: AddDecayedWeightsState + :members: + +.. autofunction:: additive_weight_decay +.. autoclass:: AdditiveWeightDecayState + :members: + +.. 
autoclass:: AddNoiseState + :members: + +.. autofunction:: apply_every +.. autoclass:: ApplyEvery + :members: + +.. autofunction:: centralize +.. autofunction:: clip +.. autofunction:: clip_by_block_rms +.. autofunction:: clip_by_global_norm +.. autoclass:: ClipByGlobalNormState + :members: + +.. autoclass:: ClipState + :members: + +.. autofunction:: ema +.. autoclass:: EmaState + :members: + +.. autoclass:: EmptyState + :members: + +.. autoclass:: FactoredState + :members: + +.. autofunction:: global_norm +.. autofunction:: identity +.. autofunction:: keep_params_nonnegative +.. autoclass:: NonNegativeParamsState + :members: + +.. autofunction:: scale +.. autofunction:: scale_by_adam +.. autofunction:: scale_by_adamax +.. autofunction:: scale_by_amsgrad +.. autofunction:: scale_by_belief +.. autofunction:: scale_by_eve +.. autofunction:: scale_by_factored_rms +.. autofunction:: scale_by_novograd +.. autofunction:: scale_by_param_block_norm +.. autofunction:: scale_by_param_block_rms +.. autofunction:: scale_by_radam +.. autofunction:: scale_by_rms +.. autofunction:: scale_by_rss +.. autofunction:: scale_by_schedule +.. autofunction:: scale_by_sm3 +.. autofunction:: scale_by_stddev +.. autofunction:: scale_by_trust_ratio +.. autofunction:: scale_by_yogi +.. autoclass:: ScaleByAdamState + :members: + +.. autoclass:: ScaleByAmsgradState + :members: + +.. autoclass:: ScaleByNovogradState + :members: + +.. autoclass:: ScaleByEveState + :members: + +.. autoclass:: ScaleByRmsState + :members: + +.. autoclass:: ScaleByRssState + :members: + + +.. autoclass:: ScaleByRStdDevState + :members: + +.. autoclass:: ScaleByScheduleState + :members: + +.. autoclass:: ScaleBySM3State + :members: + +.. autoclass:: ScaleByTrustRatioState + :members: + +.. autoclass:: ScaleState + :members: + +.. autofunction:: set_to_zero + +.. autofunction:: stateless +.. autofunction:: stateless_with_tree_map + +.. autofunction:: trace +.. autoclass:: TraceState + :members: + +.. autofunction:: zero_nans +.. autoclass:: ZeroNansState + :members: + + + +Apply Updates +============= + +.. autosummary:: + apply_updates + incremental_update + periodic_update + +apply_updates +~~~~~~~~~~~~~~~~~ + +.. autofunction:: apply_updates + +incremental_update +~~~~~~~~~~~~~~~~~~ + +.. autofunction:: incremental_update + +periodic_update +~~~~~~~~~~~~~~~ + +.. autofunction:: periodic_update + + + +Combining Optimizers +===================== + +.. currentmodule:: optax + +.. autosummary:: + + chain + multi_transform + +chain +~~~~~ + +.. autofunction:: chain + + +Multi Transform +~~~~~~~~~~~~~~~ + +.. autofunction:: multi_transform +.. autoclass:: MultiTransformState + :members: + + +Optimizer Wrappers +==================== + +.. currentmodule:: optax + +.. autosummary:: + + apply_if_finite + ApplyIfFiniteState + flatten + lookahead + LookaheadParams + LookaheadState + masked + MaskedState + maybe_update + MaybeUpdateState + MultiSteps + MultiStepsState + ShouldSkipUpdateFunction + skip_large_updates + skip_not_finite + + +Apply if Finite +~~~~~~~~~~~~~~~~~ + +.. autofunction:: apply_if_finite + +.. autoclass:: ApplyIfFiniteState + :members: + + +flatten +~~~~~~~~ + +.. autofunction:: flatten + + +Lookahead +~~~~~~~~~~~~~~~~~ + +.. autofunction:: lookahead + +.. autoclass:: LookaheadParams + :members: + +.. autoclass:: LookaheadState + :members: + + +Masked Update +~~~~~~~~~~~~~~ + +.. autofunction:: masked + +.. autoclass:: MaskedState + :members: + + + +Maybe Update +~~~~~~~~~~~~~~ + +.. autofunction:: maybe_update +.. 
autoclass:: MaybeUpdateState + :members: + + +Multi-step Update +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MultiSteps + :members: + +.. autoclass:: MultiStepsState + :members: + + +Common Losses +=============== + +.. currentmodule:: optax + +.. autosummary:: + + cosine_distance + cosine_similarity + ctc_loss + ctc_loss_with_forward_probs + hinge_loss + huber_loss + l2_loss + log_cosh + sigmoid_binary_cross_entropy + smooth_labels + softmax_cross_entropy + softmax_cross_entropy_with_integer_labels + + +Losses +~~~~~~~ + +.. autofunction:: cosine_distance +.. autofunction:: cosine_similarity +.. autofunction:: ctc_loss +.. autofunction:: ctc_loss_with_forward_probs +.. autofunction:: hinge_loss +.. autofunction:: huber_loss +.. autofunction:: l2_loss +.. autofunction:: log_cosh +.. autofunction:: sigmoid_binary_cross_entropy +.. autofunction:: smooth_labels +.. autofunction:: softmax_cross_entropy +.. autofunction:: softmax_cross_entropy_with_integer_labels + + + +Linear Algebra Operators +======================== + +.. currentmodule:: optax + +.. autosummary:: + + matrix_inverse_pth_root + multi_normal + power_iteration + + +multi_normal +~~~~~~~~~~~~ +.. autofunction:: multi_normal + + +matrix_inverse_pth_root +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: matrix_inverse_pth_root + + +Utilities for numerical stability +================================= + +.. currentmodule:: optax + +.. autosummary:: + + safe_int32_increment + safe_norm + safe_root_mean_squares + + +Numerics +~~~~~~~~ + +.. autofunction:: safe_int32_increment +.. autofunction:: safe_norm +.. autofunction:: safe_root_mean_squares + + +power_iteration +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: power_iteration + + +Optimizer Schedules +===================== + +.. currentmodule:: optax + +.. autosummary:: + + constant_schedule + cosine_decay_schedule + cosine_onecycle_schedule + exponential_decay + join_schedules + linear_onecycle_schedule + linear_schedule + piecewise_constant_schedule + piecewise_interpolate_schedule + polynomial_schedule + sgdr_schedule + warmup_cosine_decay_schedule + warmup_exponential_decay_schedule + Schedule + InjectHyperparamsState + inject_hyperparams + +Schedules +~~~~~~~~~ + +.. autofunction:: constant_schedule +.. autofunction:: cosine_decay_schedule +.. autofunction:: cosine_onecycle_schedule +.. autofunction:: exponential_decay +.. autofunction:: join_schedules +.. autofunction:: linear_onecycle_schedule +.. autofunction:: linear_schedule +.. autofunction:: piecewise_constant_schedule +.. autofunction:: piecewise_interpolate_schedule +.. autofunction:: polynomial_schedule +.. autofunction:: sgdr_schedule +.. autofunction:: warmup_cosine_decay_schedule +.. autofunction:: warmup_exponential_decay_schedule +.. autofunction:: inject_hyperparams + +.. autoclass:: Schedule + :members: + +.. autoclass:: InjectHyperparamsState + :members: + + + +Second Order Optimization Utilities +===================================== + +.. currentmodule:: optax + +.. autosummary:: + + fisher_diag + hessian_diag + hvp + +fisher_diag +~~~~~~~~~~~ + +.. autofunction:: fisher_diag + +hessian_diag +~~~~~~~~~~~~~~~~~ + +.. autofunction:: hessian_diag + +hvp +~~~~~~~~~~~ + +.. autofunction:: hvp + + + + + + +Control Variates +================ + +.. currentmodule:: optax + +.. autosummary:: + + control_delta_method + control_variates_jacobians + moving_avg_baseline + +control_delta_method +~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: control_delta_method + +control_variates_jacobians +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autofunction:: control_variates_jacobians + +moving_avg_baseline +~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: moving_avg_baseline + + + + +Stochastic Gradient Estimators +============================== + +.. currentmodule:: optax + +.. autosummary:: + + measure_valued_jacobians + pathwise_jacobians + score_function_jacobians + +measure_valued_jacobians +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: measure_valued_jacobians + +pathwise_jacobians +~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pathwise_jacobians + +score_function_jacobians +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: score_function_jacobians + + + +Privacy-Sensitive Optax Methods +================================== + +.. currentmodule:: optax + +.. autosummary:: + + DifferentiallyPrivateAggregateState + differentially_private_aggregate + + +differentially_private_aggregate +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: differentially_private_aggregate + +.. autoclass:: DifferentiallyPrivateAggregateState + :members: + + + +General Utilities +===================================== + +.. currentmodule:: optax + +.. autosummary:: + + multi_normal + scale_gradient + +multi_normal +~~~~~~~~~~~~ + +.. autofunction:: multi_normal + +scale_gradient +~~~~~~~~~~~~~~~~~ + +.. autofunction:: scale_gradient + + +🚧 Experimental +=============== + +.. currentmodule:: optax.experimental + +.. autosummary:: + + split_real_and_imaginary + SplitRealAndImaginaryState + + +Complex-Valued Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: split_real_and_imaginary + +.. autoclass:: SplitRealAndImaginaryState + :members: diff --git a/optax/__init__.py b/optax/__init__.py index 4de165669..a41023418 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -30,6 +30,7 @@ from optax._src.alias import adamaxw from optax._src.alias import adamw from optax._src.alias import amsgrad +from optax._src.alias import eve from optax._src.alias import fromage from optax._src.alias import lamb from optax._src.alias import lars @@ -107,6 +108,7 @@ from optax._src.transform import scale_by_amsgrad from optax._src.transform import scale_by_belief from optax._src.transform import scale_by_distance_over_gradients +from optax._src.transform import scale_by_eve from optax._src.transform import scale_by_learning_rate from optax._src.transform import scale_by_lion from optax._src.transform import scale_by_novograd @@ -126,6 +128,7 @@ from optax._src.transform import ScaleByAdamState from optax._src.transform import ScaleByAmsgradState from optax._src.transform import ScaleByBeliefState +from optax._src.transform import ScaleByEveState from optax._src.transform import ScaleByLionState from optax._src.transform import ScaleByNovogradState from optax._src.transform import ScaleByRmsState @@ -254,6 +257,7 @@ "ema", "EmaState", "EmptyState", + "eve", "exponential_decay", "FactoredState", "flatten", @@ -317,6 +321,7 @@ "scale_by_adamax", "scale_by_amsgrad", "scale_by_belief", + "scale_by_eve", "scale_by_lion", "scale_by_factored_rms", "scale_by_novograd", @@ -337,6 +342,7 @@ "ScaleByAdamState", "ScaleByAmsgradState", "ScaleByBeliefState", + "ScaleByEveState", "ScaleByLionState", "ScaleByNovogradState", "ScaleByRmsState", diff --git a/optax/_src/alias.py b/optax/_src/alias.py index ea29f1c3e..cc6de45ca 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -25,6 +25,7 @@ from optax._src import factorized from optax._src import transform from optax._src import wrappers +import optax.schedules as schedules MaskOrFn = Optional[Union[Any, 
Callable[[base.Params], Any]]] @@ -793,6 +794,129 @@ def amsgrad( transform.scale_by_learning_rate(learning_rate), ) +def _eve( + a1: float = 1e-3, + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.999, + c: float = 10., + eps: float = 1e-8, + f: float = 1., + f_star: float = 0., + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """The Eve optimizer (uninjectable, see `eve()`). + + Eve is an SGD variant with adaptive global and local learning rates. + The local learning rate used for each weight is computed from estimates of + first- and second-order moments of the gradients (using suitable exponential + moving averages) as in ADAM. These are then scaled by the global learning + rate `a1`, which is adaptively modified by some notion of sub-optimality `d`: + increasing the global rate when far from optimal and decreasing it when + approaching optimality. This is also computed with exponential moving + averages, similar to the first and second moments. + + References: + Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 + + Args: + a1: this is the initial global scaling factor. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. (needs to be injected before update is called) + f_star: estimation of the global minimum + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + the corresponding `GradientTransformation` + + Note: + Eve requires an additional parameter: the loss for the current iteration:: + + f := f_t + + ScaleByEveState also holds the loss from the previous iteration:: + + state.f_prev := f_{t-1} + + Since it is up to the user to inject the current loss before calling the + update function, the `eve` alias returns an injectable state by default by + wrapping `_eve` in `inject_hyperparams`. + """ + return combine.chain( + transform.scale_by_eve( + b1=b1, b2=b2, b3=b3, c=c, eps=eps, f=f, f_star=f_star, mu_dtype=mu_dtype), + _scale_by_learning_rate(a1) + ) + + +def eve( + a1: float = 1e-3, + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.999, + c: float = 10., + eps: float = 1e-8, + f: float = 1., + f_star: float = 0., + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Injectable Eve optimizer. + + Eve requires an additional parameter: the loss for the current iteration:: + + f := f_t + + ScaleByEveState also holds the loss from the previous iteration:: + + state.f_prev := f_{t-1} + + Since it is up to the user to inject the current loss before calling the + update function, the `eve` alias returns an injectable state by default by + wrapping `_eve` in `inject_hyperparams`. + + Args: + a1: this is the initial global scaling factor. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. 
+ c: the clipping limit to prevent extreme global learning rate changes + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. (needs to be injected before update is called) + f_star: estimation of the global minimum + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + the corresponding `GradientTransformation` wrapped in inject_hyperparams + + Inject the current loss as follows: + ----------------------------------- + + Initialize:: + + optimizer = optax.eve() + opt_state = optimizer.init(params) + + Train:: + + while training: + loss, grads = jax.value_and_grad(loss_fn)(params, data) + opt_state.hyperparams['f'] = loss # <-- Update state here + updates, opt_state = optimizer.update(grads, opt_state) + params = optax.apply_updates(params, updates) + """ + return schedules.inject_hyperparams(_eve)( + a1=a1, b1=b1, b2=b2, b3=b3, c=c, eps=eps, + f=f, f_star=f_star, mu_dtype=mu_dtype + ) + def fromage( learning_rate: float, diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index f80d9b644..c1134c03f 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -38,6 +38,7 @@ dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='eve', opt_kwargs=dict(f=10)), dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)), dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)), dict( @@ -124,6 +125,10 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype): @jax.jit def step(params, state): updates = get_updates(params) + # TODO: double check this has to be here + if opt_name == 'eve': + f = jnp.mean(jnp.square(params-final_params)) + state.hyperparams['f'] = f # Complex gradients need to be conjugated before being added to parameters # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 updates = jax.tree_util.tree_map(lambda x: x.conj(), updates) @@ -155,6 +160,10 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( # https://github.com/deepmind/optax/issues/412. opt_inject = _inject.inject_hyperparams( opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs) + elif opt_name == 'eve': + # Eve is injectable by default. Reassign opt to uninjectable _eve alias + opt = alias._eve(**opt_kwargs) + opt_inject = opt_factory(**opt_kwargs) else: opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 7fc81179d..9a7036784 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -476,6 +476,77 @@ def update_fn(updates, state, params=None): return base.GradientTransformation(init_fn, update_fn) +class ScaleByEveState(NamedTuple): + """State for the Eve algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + d: float + f_prev: float + + +def scale_by_eve( + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.999, + c: float = 10., + eps: float = 1e-8, + f: float = 1., + f_star: float = 0., + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Rescale updates according to the Eve algorithm. 
+
+  References:
+    [Hayashi et al., 2018](https://arxiv.org/abs/1611.01505)
+
+  Args:
+    b1: the exponential decay rate to track the first moment of past gradients.
+    b2: the exponential decay rate to track the second moment of past gradients.
+    b3: the exponential decay rate to track the sub-optimality.
+    c: the clipping limit to prevent extreme global learning rate changes.
+    eps: a small constant applied to denominator outside of the square root
+      (as in the Adam paper) to avoid dividing by zero when rescaling.
+    f: the current loss value (must be injected before `update` is called).
+    f_star: estimation of the global minimum.
+    mu_dtype: optional `dtype` to be used for the first order accumulator; if
+      `None` then the `dtype` is inferred from `params` and `updates`.
+
+  Returns:
+    An (init_fn, update_fn) tuple.
+  """
+  mu_dtype = utils.canonicalize_dtype(mu_dtype)
+
+  def init_fn(params):
+    mu = jax.tree_util.tree_map(  # First moment
+        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
+    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
+    return ScaleByEveState(
+        count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=f
+    )
+
+
+  def update_fn(updates: base.Updates, state: ScaleByEveState, params=None):
+    del params
+    mu = update_moment(updates, state.mu, b1, 1)
+    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
+    count_inc = numerics.safe_int32_increment(state.count)
+    mu_hat = bias_correction(mu, b1, count_inc)
+    nu_hat = bias_correction(nu, b2, count_inc)
+    d_new = (jnp.abs(f - state.f_prev) /
+             (jnp.min(jnp.array([f, state.f_prev])) - f_star))
+    d_hat = jnp.clip(d_new, 1 / c, c)
+    d = jnp.where(count_inc > 1, b3 * state.d + (1 - b3) * d_hat, 1.)
+    updates = jax.tree_util.tree_map(
+        lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat)
+    mu = utils.cast_tree(mu, mu_dtype)
+    return updates, ScaleByEveState(
+        count=count_inc, mu=mu, nu=nu, d=d, f_prev=f
+    )
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
 class ScaleByLionState(NamedTuple):
   """State for the Lion algorithm."""
   count: chex.Array  # shape=(), dtype=jnp.int32.
diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py
index feb05ce48..4829cc414 100644
--- a/optax/_src/transform_test.py
+++ b/optax/_src/transform_test.py
@@ -45,6 +45,7 @@ def setUp(self):
       ('adadelta', transform.scale_by_adadelta),
       ('adam', transform.scale_by_adam),
       ('adamax', transform.scale_by_adamax),
+      ('eve', transform.scale_by_eve),
      ('lion', transform.scale_by_lion),
      ('rmsprop', transform.scale_by_rms),
      ('stddev', transform.scale_by_stddev),
diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py
index 9d0d50bc3..745cad406 100644
--- a/optax/_src/wrappers.py
+++ b/optax/_src/wrappers.py
@@ -24,7 +24,8 @@
 from optax._src import base
 from optax._src import numerics
 from optax.tree_utils import _state_utils
+import optax.tree_utils as tu

 Array = jnp.ndarray
@@ -610,3 +611,68 @@ def reject_update(_):
        numerics.safe_int32_increment(state.step))
  return base.GradientTransformationExtraArgs(init_fn, update_fn)
+
+
+class EveState(NamedTuple):
+  """Maintains inner transform state and adds a step counter.
+
+  Attributes:
+    inner_state: the state of the wrapped optimizer.
+    step: the counter for the current step (t).
+    f_prev: the loss value from the previous step.
+    d_tilde_prev: the smoothed sub-optimality factor from the previous step.
+ """ + inner_state: base.OptState + step: Union[jax.Array, int] + f_prev: Union[jax.Array, float] + d_tilde_prev: Union[jax.Array, float] + + +def eve( + inner: base.GradientTransformation, + b3: float, + c: float, + f: Union[jax.Array, float], + f_star: Union[jax.Array, float] +) -> base.GradientTransformation: + """Eve optimizer. + + Args: + inner: the inner transformation. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes. + f: the current loss value. + f_star: the estimated global minimum. + + Returns: + New ``GradientTransformation``. + """ + + def init_fn(params): + return EveState( + inner_state=inner.init(params), + step=0, + f_prev=f + ) + + def update_fn(updates, state, params=None): + del params + step = utils.numerics.safe_int32_increment(state.step) + d = (jnp.abs(f - state.f_prev) / + (jnp.min(jnp.array([f, state.f_prev])) - f_star) + ) + d_hat = jnp.clip(d, 1 / c, c) + d_tilde = jnp.where(step > 1, b3 * state.d_tilde_prev + (1 - b3) * d_hat, 1.) + + new_inner_updates, new_inner_state = inner.update(updates, state.inner_state) + new_updates = tu.tree_scalar_mul(1 / d_tilde, new_inner_updates) + return new_updates, EveState(inner_state=new_inner_state, + step=step, f_prev=f) + + return base.GradientTransformation(init_fn, update_fn) + \ No newline at end of file