Remove longtime deprecated functions. #1149

Merged: 1 commit, Dec 17, 2024
docs/api/optimizer_wrappers.rst (2 changes: 0 additions & 2 deletions)

@@ -12,8 +12,6 @@ Optimizer Wrappers
 LookaheadState
 masked
 MaskedState
-maybe_update
-MaybeUpdateState
 MultiSteps
 MultiStepsState
 ShouldSkipUpdateFunction
optax/__init__.py (4 changes: 0 additions & 4 deletions)

@@ -140,8 +140,6 @@
 from optax._src.utils import multi_normal
 from optax._src.utils import scale_gradient
 from optax._src.utils import value_and_grad_from_state
-from optax._src.wrappers import maybe_update
-from optax._src.wrappers import MaybeUpdateState

 # TODO(mtthss): remove tree_utils aliases after updates.
 adaptive_grad_clip = transforms.adaptive_grad_clip

@@ -372,8 +370,6 @@
 "MaskOrFn",
 "MaskedState",
 "matrix_inverse_pth_root",
-"maybe_update",
-"MaybeUpdateState",
 "multi_normal",
 "multi_transform",
 "MultiSteps",
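A quick way to confirm the public aliases are gone after this change (a minimal check, assuming an optax build that includes this PR):

```python
# Illustrative check only; passes on versions that include this removal.
import optax

assert not hasattr(optax, 'maybe_update')
assert not hasattr(optax, 'MaybeUpdateState')
```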
optax/_src/wrappers.py (22 changes: 0 additions & 22 deletions)

@@ -14,12 +14,6 @@
 # ==============================================================================
 """Transformation wrappers."""

-from collections.abc import Callable
-import functools
-
-import chex
-import jax.numpy as jnp
-from optax._src import base
 from optax.transforms import _accumulation
 from optax.transforms import _conditionality
 from optax.transforms import _layouts

@@ -42,19 +36,3 @@
 ShouldSkipUpdateFunction = _accumulation.ShouldSkipUpdateFunction
 skip_not_finite = _accumulation.skip_not_finite
 skip_large_updates = _accumulation.skip_large_updates
-
-
-@functools.partial(
-    chex.warn_deprecated_function,
-    replacement='optax.transforms.maybe_transform',
-)
-def maybe_update(
-    inner: base.GradientTransformation,
-    should_update_fn: Callable[[jnp.ndarray], jnp.ndarray],
-) -> base.GradientTransformationExtraArgs:
-  return conditionally_transform(
-      inner=inner, should_transform_fn=should_update_fn
-  )
-
-
-MaybeUpdateState = ConditionallyTransformState
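For downstream code that still calls the removed wrapper, a migration sketch follows. It assumes `optax.transforms.conditionally_transform`, which `maybe_update` forwarded to, keeps the same semantics; the step predicate and scaling value below are illustrative only.

```python
# Migration sketch (not part of this PR): maybe_update was a thin alias
# around conditionally_transform, so call sites can switch directly.
import jax.numpy as jnp
import optax


def should_update(step: jnp.ndarray) -> jnp.ndarray:
  # Illustrative predicate: apply the inner transformation on even steps.
  return step % 2 == 0


# Before (removed by this PR):
#   tx = optax.maybe_update(optax.scale(-0.1), should_update)
# After:
tx = optax.transforms.conditionally_transform(
    inner=optax.scale(-0.1), should_transform_fn=should_update
)

params = {'w': jnp.ones(3)}
opt_state = tx.init(params)
grads = {'w': jnp.full(3, 0.5)}
updates, opt_state = tx.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
```

Since the removed `MaybeUpdateState` was an alias of `ConditionallyTransformState`, any isinstance checks can reference that class instead.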