diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 70659a782..c23abcceb 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -38,7 +38,7 @@ class TraceState(NamedTuple): def trace( - decay: float, + decay: Union[float, jax.Array], nesterov: bool = False, accumulator_dtype: Optional[Any] = None, ) -> base.GradientTransformation: @@ -134,7 +134,7 @@ class EmaState(NamedTuple): def ema( - decay: float, + decay: Union[float, jax.Array], debias: bool = True, accumulator_dtype: Optional[Any] = None ) -> base.GradientTransformation: @@ -180,8 +180,8 @@ class ScaleByRssState(NamedTuple): def scale_by_rss( - initial_accumulator_value: float = 0.1, - eps: float = 1e-7 + initial_accumulator_value: Union[float, jax.Array] = 0.1, + eps: Union[float, jax.Array] = 1e-7 ) -> base.GradientTransformation: """Rescale updates by the root of the sum of all squared gradients to date. @@ -221,9 +221,9 @@ class ScaleByRmsState(NamedTuple): def scale_by_rms( - decay: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0. + decay: Union[float, jax.Array] = 0.9, + eps: Union[float, jax.Array] = 1e-8, + initial_scale: Union[float, jax.Array] = 0. ) -> base.GradientTransformation: """Rescale updates by the root of the exp. moving avg of the square. @@ -261,9 +261,9 @@ class ScaleByRStdDevState(NamedTuple): def scale_by_stddev( - decay: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0. + decay: Union[float, jax.Array] = 0.9, + eps: Union[float, jax.Array] = 1e-8, + initial_scale: Union[float, jax.Array] = 0. ) -> base.GradientTransformation: """Rescale updates by the root of the centered exp. moving average of squares. @@ -305,10 +305,10 @@ class ScaleByAdamState(NamedTuple): def scale_by_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.999, + eps: Union[float, jax.Array] = 1e-8, + eps_root: Union[float, jax.Array] = 0.0, mu_dtype: Optional[chex.ArrayDType] = None, ) -> base.GradientTransformation: """Rescale updates according to the Adam algorithm. @@ -361,10 +361,10 @@ class ScaleByAmsgradState(NamedTuple): def scale_by_amsgrad( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.999, + eps: Union[float, jax.Array] = 1e-8, + eps_root: Union[float, jax.Array] = 0.0, mu_dtype: Optional[chex.ArrayDType] = None, ) -> base.GradientTransformation: """Rescale updates according to the AMSGrad algorithm. @@ -413,9 +413,9 @@ def update_fn(updates, state, params=None): def scale_by_adamax( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8 + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.999, + eps: Union[float, jax.Array] = 1e-8 ) -> base.GradientTransformation: """Rescale updates according to the Adamax algorithm. @@ -456,8 +456,8 @@ class ScaleByLionState(NamedTuple): def scale_by_lion( - b1: float = 0.9, - b2: float = 0.99, + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.99, mu_dtype: Optional[chex.ArrayDType] = None, ) -> base.GradientTransformation: """Rescale updates according to the Lion algorithm. @@ -498,7 +498,7 @@ def update_fn(updates, state, params=None): def scale( - step_size: float + step_size: Union[float, jax.Array] ) -> base.GradientTransformation: """Scale updates by some fixed scalar `step_size`. @@ -522,7 +522,7 @@ def update_fn(updates, state, params=None): def scale_by_param_block_norm( - min_scale: float = 1e-3 + min_scale: Union[float, jax.Array] = 1e-3 ) -> base.GradientTransformation: """Scale updates for each param block by the norm of that block's parameters. @@ -552,7 +552,7 @@ def update_fn(updates, state, params): def scale_by_param_block_rms( - min_scale: float = 1e-3 + min_scale: Union[float, jax.Array] = 1e-3 ) -> base.GradientTransformation: """Scale updates by rms of the gradient for each param vector or matrix. @@ -589,10 +589,10 @@ class ScaleByBeliefState(NamedTuple): def scale_by_belief( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-16, - eps_root: float = 1e-16 + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.999, + eps: Union[float, jax.Array] = 1e-16, + eps_root: Union[float, jax.Array] = 1e-16 ) -> base.GradientTransformation: """Rescale updates according to the AdaBelief algorithm. @@ -634,11 +634,11 @@ def update_fn(updates, state, params=None): def scale_by_yogi( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-3, - eps_root: float = 0.0, - initial_accumulator_value: float = 1e-6 + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.999, + eps: Union[float, jax.Array] = 1e-3, + eps_root: Union[float, jax.Array] = 0.0, + initial_accumulator_value: Union[float, jax.Array] = 1e-6 ) -> base.GradientTransformation: """Rescale updates according to the Yogi algorithm. @@ -684,11 +684,11 @@ def update_fn(updates, state, params=None): def scale_by_radam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - threshold: float = 5.0 + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.999, + eps: Union[float, jax.Array] = 1e-8, + eps_root: Union[float, jax.Array] = 0.0, + threshold: Union[float, jax.Array] = 5.0 ) -> base.GradientTransformation: """Rescale updates according to the Rectified Adam algorithm. @@ -816,9 +816,9 @@ class ScaleByTrustRatioState(NamedTuple): def scale_by_trust_ratio( - min_norm: float = 0.0, - trust_coefficient: float = 1., - eps: float = 0., + min_norm: Union[float, jax.Array] = 0.0, + trust_coefficient: Union[float, jax.Array] = 1., + eps: Union[float, jax.Array] = 0., ) -> base.GradientTransformation: """Scale updates by `trust ratio`. @@ -870,8 +870,8 @@ class AddNoiseState(NamedTuple): def add_noise( - eta: float, - gamma: float, + eta: Union[float, jax.Array], + gamma: Union[float, jax.Array], seed: int ) -> base.GradientTransformation: """Add gradient noise. @@ -993,9 +993,9 @@ class ScaleBySM3State(NamedTuple): def scale_by_sm3( - b1: float = 0.9, - b2: float = 1.0, - eps: float = 1e-8 + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 1.0, + eps: Union[float, jax.Array] = 1e-8 ) -> base.GradientTransformation: """Scale updates by `sm3`. @@ -1069,11 +1069,11 @@ class ScaleByNovogradState(NamedTuple): def scale_by_novograd( - b1: float = 0.9, - b2: float = 0.25, - eps: float = 1e-8, - eps_root: float = 0.0, - weight_decay: float = 0.0, + b1: Union[float, jax.Array] = 0.9, + b2: Union[float, jax.Array] = 0.25, + eps: Union[float, jax.Array] = 1e-8, + eps_root: Union[float, jax.Array] = 0.0, + weight_decay: Union[float, jax.Array] = 0.0, mu_dtype: Optional[chex.ArrayDType] = None, ) -> base.GradientTransformation: """Computes NovoGrad updates. @@ -1141,8 +1141,8 @@ def update_fn(updates, state, params): return base.GradientTransformation(init_fn, update_fn) -def scale_by_optimistic_gradient(alpha: float = 1.0, - beta: float = 1.0 +def scale_by_optimistic_gradient(alpha: Union[float, jax.Array] = 1.0, + beta: Union[float, jax.Array] = 1.0 ) -> base.GradientTransformation: """Compute generalized optimistic gradients.