Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AdEMAMix Optimizer #1057

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Optimizers
adamw
adamax
adamaxw
ademamix
amsgrad
fromage
lamb
Expand Down Expand Up @@ -64,6 +65,10 @@ AdamW
~~~~~
.. autofunction:: adamw

AdEMAMix
~~~~~~~~
.. autofunction:: ademamix

AMSGrad
~~~~~~~
.. autofunction:: amsgrad
Expand Down
33 changes: 21 additions & 12 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from optax._src.alias import adamax
from optax._src.alias import adamaxw
from optax._src.alias import adamw
from optax._src.alias import ademamix
from optax._src.alias import amsgrad
from optax._src.alias import fromage
from optax._src.alias import lamb
Expand Down Expand Up @@ -99,7 +100,6 @@
from optax._src.lookahead import lookahead
from optax._src.lookahead import LookaheadParams
from optax._src.lookahead import LookaheadState
from optax._src.numerics import safe_increment
from optax._src.numerics import safe_int32_increment
from optax._src.numerics import safe_norm
from optax._src.numerics import safe_root_mean_squares
Expand All @@ -117,6 +117,7 @@
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adam
from optax._src.transform import scale_by_adamax
from optax._src.transform import scale_by_ademamix
from optax._src.transform import scale_by_amsgrad
from optax._src.transform import scale_by_belief
from optax._src.transform import scale_by_distance_over_gradients
Expand All @@ -140,6 +141,7 @@
from optax._src.transform import scale_by_yogi
from optax._src.transform import ScaleByAdaDeltaState
from optax._src.transform import ScaleByAdamState
from optax._src.transform import ScaleByAdemamixState
from optax._src.transform import ScaleByAmsgradState
from optax._src.transform import ScaleByBeliefState
from optax._src.transform import ScaleByLBFGSState
Expand Down Expand Up @@ -232,8 +234,12 @@
# pylint: disable=g-import-not-at-top
# TODO(mtthss): remove contrib aliases from flat namespace once users updated.
# Deprecated modules
from optax.contrib import differentially_private_aggregate as _deprecated_differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState
from optax.contrib import (
differentially_private_aggregate as _deprecated_differentially_private_aggregate,
)
from optax.contrib import (
DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState,
)
from optax.contrib import dpsgd as _deprecated_dpsgd

_deprecations = {
Expand Down Expand Up @@ -266,17 +272,18 @@
import typing as _typing

if _typing.TYPE_CHECKING:
# pylint: disable=reimported
from optax.contrib import differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd
# pylint: enable=reimported
# pylint: disable=reimported
from optax.contrib import differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd

# pylint: enable=reimported

else:
from optax._src.deprecations import deprecation_getattr as _deprecation_getattr
from optax._src.deprecations import deprecation_getattr as _deprecation_getattr

__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
# pylint: enable=g-bad-import-order
# pylint: enable=g-import-not-at-top
Expand All @@ -300,6 +307,7 @@
"add_noise",
"AddDecayedWeightsState",
"AddNoiseState",
"ademamix",
"amsgrad",
"apply_every",
"apply_if_finite",
Expand Down Expand Up @@ -388,14 +396,14 @@
"radam",
"rmsprop",
"rprop",
"safe_increment",
"safe_int32_increment",
"safe_norm",
"safe_root_mean_squares",
"ScalarOrSchedule",
"scale_by_adadelta",
"scale_by_adam",
"scale_by_adamax",
"scale_by_ademamix",
"scale_by_amsgrad",
"scale_by_backtracking_linesearch",
"scale_by_belief",
Expand All @@ -421,6 +429,7 @@
"scale",
"ScaleByAdaDeltaState",
"ScaleByAdamState",
"ScaleByAdemamixState",
"ScaleByAmsgradState",
"ScaleByBacktrackingLinesearchState",
"ScaleByBeliefState",
Expand Down
Loading
Loading