From 7e74ec758324fff86357caed203f26ad46effc46 Mon Sep 17 00:00:00 2001 From: Tony-Y <11532812+Tony-Y@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:47:08 +0900 Subject: [PATCH 1/3] Update Documentation --- .github/workflows/sphinx-gh-pages.yml | 2 +- LICENSE | 2 +- MANIFEST.in | 2 + README.md | 41 +++++--- docs/conf.py | 9 +- docs/index.rst | 73 +++++++++++++ examples/emnist/README.md | 93 +++++++++++++++++ examples/plots/README.md | 57 ++++++++++ pytorch_warmup/base.py | 143 ++++++++++++++++++++++---- pytorch_warmup/radam.py | 101 +++++++++++++++++- pytorch_warmup/untuned.py | 117 +++++++++++++++++++-- 11 files changed, 594 insertions(+), 46 deletions(-) create mode 100644 examples/emnist/README.md create mode 100644 examples/plots/README.md diff --git a/.github/workflows/sphinx-gh-pages.yml b/.github/workflows/sphinx-gh-pages.yml index a018492..43da29c 100644 --- a/.github/workflows/sphinx-gh-pages.yml +++ b/.github/workflows/sphinx-gh-pages.yml @@ -35,7 +35,7 @@ jobs: run: | python -m pip install --upgrade pip pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install sphinx sphinxcontrib-katex sphinx-rtd-theme + pip install sphinx sphinxcontrib-katex sphinx-copybutton sphinx-rtd-theme - name: Sphinx Build run: | cd docs/ diff --git a/LICENSE b/LICENSE index dc0eee3..19c1d48 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2019 Takenori Yamamoto +Copyright (c) 2019-2024 Takenori Yamamoto Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/MANIFEST.in b/MANIFEST.in index 1cde148..5646fca 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,7 @@ include test/__init__.py +include examples/plots/README.md include examples/plots/*.py include examples/plots/figs/*.png +include examples/emnist/README.md include examples/emnist/*.py include examples/emnist/figs/*.png diff --git a/README.md b/README.md index 3697e77..fb43b61 100644 --- a/README.md +++ b/README.md @@ -4,27 +4,36 @@ This library contains PyTorch implementations of the warmup schedules described

Warmup schedule

-![Python package](https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg) +[![Python package](https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg)](https://github.com/Tony-Y/pytorch_warmup/) [![PyPI version shields.io](https://img.shields.io/pypi/v/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/) -[![PyPI license](https://img.shields.io/pypi/l/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/) -[![PyPI pyversions](https://img.shields.io/pypi/pyversions/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/) +[![PyPI license](https://img.shields.io/pypi/l/pytorch-warmup.svg)](https://github.com/Tony-Y/pytorch_warmup/blob/master/LICENSE) +[![Python versions](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org) ## Installation -Make sure you have Python 3.7+ and PyTorch 1.1+. Then, run the following command in the project directory: +Make sure you have Python 3.7+ and PyTorch 1.1+ or 2.x. Then, run the following command in the project directory: -``` +```shell python -m pip install . ``` or install the latest version from the Python Package Index: -``` +```shell pip install -U pytorch_warmup ``` +## Examples + +* [EMNIST](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/emnist) - + A sample script to train a CNN model on the EMNIST dataset using the Adam algorithm with a warmup. +* [Plots](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/plots) - + A script to plot effective warmup periods as a function of 𝛽₂, and warmup schedules over time. + ## Usage +The [Documentation](https://tony-y.github.io/pytorch_warmup/) provides more detailed information on this library, unseen below. + ### Sample Codes The scheduled learning rate is dampened by the multiplication of the warmup factor: @@ -34,16 +43,20 @@ The scheduled learning rate is dampened by the multiplication of the warmup fact #### Approach 1 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach1_chaining.ipynb) -When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used as follows: +When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used +together with `Adam` or its variant (`AdamW`, `NAdam`, etc.) as follows: ```python import torch import pytorch_warmup as warmup optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01) + # This sample code uses the AdamW optimizer. num_steps = len(dataloader) * num_epochs lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps) + # The LR schedule initialization resets the initial LR of the optimizer. warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) + # The warmup schedule initialization dampens the initial LR of the optimizer. for epoch in range(1,num_epochs+1): for batch in dataloader: optimizer.zero_grad() @@ -53,9 +66,9 @@ for epoch in range(1,num_epochs+1): with warmup_scheduler.dampening(): lr_scheduler.step() ``` -Note that the warmup schedule must not be initialized before the learning rate schedule. +Note that the warmup schedule must not be initialized before the initialization of the learning rate schedule. -If you want to use the learning rate schedule "chaining" which is supported for PyTorch 1.4.0 or above, you may simply give a code of learning rate schedulers as a suite of the `with` statement: +If you want to use the learning rate schedule *chaining*, which is supported for PyTorch 1.4 or above, you may simply write a code of learning rate schedulers as a suite of the `with` statement: ```python lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1) @@ -163,7 +176,7 @@ warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=1000) #### Untuned Warmup -The warmup period is given by a function of Adam's `beta2` parameter for `UntunedLinearWarmup` and `UntunedExponentialWarmup`. +The warmup period is determined by a function of Adam's `beta2` parameter for `UntunedLinearWarmup` and `UntunedExponentialWarmup`. ##### Linear @@ -183,7 +196,9 @@ warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer) #### RAdam Warmup -The warmup factor depends on Adam's `beta2` parameter for `RAdamWarmup`. Please see the original paper for the details. +The warmup factor depends on Adam's `beta2` parameter for `RAdamWarmup`. For details please refer to the +[Documentation](https://tony-y.github.io/pytorch_warmup/radam_warmup.html) or +"[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)." ```python warmup_scheduler = warmup.RAdamWarmup(optimizer) @@ -191,7 +206,7 @@ warmup_scheduler = warmup.RAdamWarmup(optimizer) ### Apex's Adam -The Apex library provides an Adam optimizer tuned for CUDA devices, [FusedAdam](https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedAdam). The FusedAdam optimizer can be used with the warmup schedulers. For example: +The Apex library provides an Adam optimizer tuned for CUDA devices, [FusedAdam](https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedAdam). The FusedAdam optimizer can be used together with any one of the warmup schedules above. For example: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_FusedAdam.ipynb) @@ -206,4 +221,4 @@ warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) MIT License -Copyright (c) 2019 Takenori Yamamoto +© 2019-2024 Takenori Yamamoto diff --git a/docs/conf.py b/docs/conf.py index f6e8abd..fbf139e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,7 @@ # -- Project information ----------------------------------------------------- project = 'PyTorch Warmup' -copyright = '2019, Takenori Yamamoto' +copyright = '2019-2024, Takenori Yamamoto' author = 'Takenori Yamamoto' @@ -40,6 +40,7 @@ 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinxcontrib.katex', + 'sphinx_copybutton', ] # Add any paths that contain templates here, relative to this directory. @@ -61,4 +62,8 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] + +# Copybutton settings +copybutton_prompt_text = r">>> |\.\.\. |\$ " +copybutton_prompt_is_regexp = True diff --git a/docs/index.rst b/docs/index.rst index dc6ccc6..862f56a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,79 @@ This library contains PyTorch implementations of the warmup schedules described `On the adequacy of untuned warmup for adaptive optimization `_. +.. image:: https://github.com/Tony-Y/pytorch_warmup/raw/master/examples/plots/figs/warmup_schedule.png + :alt: Warmup schedule + :width: 400 + :align: center + +.. image:: https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg + :alt: Python package + :target: https://github.com/Tony-Y/pytorch_warmup/ + +.. image:: https://img.shields.io/pypi/v/pytorch-warmup.svg + :alt: PyPI version shields.io + :target: https://pypi.python.org/pypi/pytorch-warmup/ + +.. image:: https://img.shields.io/pypi/l/pytorch-warmup.svg + :alt: PyPI license + :target: https://github.com/Tony-Y/pytorch_warmup/blob/master/LICENSE + +.. image:: https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue + :alt: Python versions + :target: https://www.python.org + +Installation +------------ + +Make sure you have Python 3.7+ and PyTorch 1.1+ or 2.x. Then, install the latest version from the Python Package Index: + +.. code-block:: shell + + pip install -U pytorch_warmup + +Examples +-------- + +.. image:: https://colab.research.google.com/assets/colab-badge.svg + :alt: Open In Colab + :target: https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach1_chaining.ipynb + +* `EMNIST `_ - + A sample script to train a CNN model on the EMNIST dataset using the Adam algorithm with a warmup. + +* `Plots `_ - + A script to plot effective warmup periods as a function of :math:`\beta_{2}`, and warmup schedules over time. + +Usage +----- + +When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used +together with :class:`Adam` or its variant (:class:`AdamW`, :class:`NAdam`, etc.) as follows: + +.. code-block:: python + + import torch + import pytorch_warmup as warmup + + optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01) + # This sample code uses the AdamW optimizer. + num_steps = len(dataloader) * num_epochs + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps) + # The LR schedule initialization resets the initial LR of the optimizer. + warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) + # The warmup schedule initialization dampens the initial LR of the optimizer. + for epoch in range(1,num_epochs+1): + for batch in dataloader: + optimizer.zero_grad() + loss = ... + loss.backward() + optimizer.step() + with warmup_scheduler.dampening(): + lr_scheduler.step() + +Note that the warmup schedule must not be initialized before the initialization of the learning rate schedule. +Other approaches can be found in `README `_. + .. toctree:: :maxdepth: 2 :caption: Contents: diff --git a/examples/emnist/README.md b/examples/emnist/README.md new file mode 100644 index 0000000..4c0490c --- /dev/null +++ b/examples/emnist/README.md @@ -0,0 +1,93 @@ +# EMNIST Example + +Requirements: `pytorch_warmup` and `torchvision`. + +

+ Accuracy
+ Test accuracy over time for each warmup schedule. +

+ +

+ Accuracy
+ Learning rate over time for each warmup schedule. +

+ +## Download EMNIST Dataset + +Run the Python script `download.py` to download the EMNIST dataset: + +```shell +python download.py +``` + +This script shows download progress: + +``` +Downloading zip archive +Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to .data/EMNIST/raw/gzip.zip +100.0% +``` + +## Train A CNN Model + +Run the Python script `main.py` to train a CNN model on the EMNIST dataset using the Adam algorithm. + +### Untuned Linear Warmup + +Train a CNN model with the *Untuned Linear Warmup* schedule: + +``` +python main.py --warmup linear +``` + +### Untuned Exponential Warmup + +Train a CNN model with the *Untuned Exponential Warmup* schedule: + +``` +python main.py --warmup exponential +``` + +### RAdam Warmup + +Train a CNN model with the *RAdam Warmup* schedule: + +``` +python main.py --warmup radam +``` + +### No Warmup + +Train a CNN model without warmup: + +``` +python main.py --warmup none +``` + +### Usage + +``` +usage: main.py [-h] [--batch-size N] [--test-batch-size N] [--epochs N] [--lr LR] + [--lr-min LM] [--wd WD] [--beta2 B2] [--no-cuda] [--seed S] + [--log-interval N] [--warmup {linear,exponential,radam,none}] [--save-model] + +PyTorch EMNIST Example + +options: + -h, --help show this help message and exit + --batch-size N input batch size for training (default: 64) + --test-batch-size N input batch size for testing (default: 1000) + --epochs N number of epochs to train (default: 10) + --lr LR base learning rate (default: 0.01) + --lr-min LM minimum learning rate (default: 1e-5) + --wd WD weight decay (default: 0.01) + --beta2 B2 Adam's beta2 parameter (default: 0.999) + --no-cuda disables CUDA training + --seed S random seed (default: 1) + --log-interval N how many batches to wait before logging training status + --warmup {linear,exponential,radam,none} + warmup schedule + --save-model For Saving the current Model +``` + +© 2024 Takenori Yamamoto \ No newline at end of file diff --git a/examples/plots/README.md b/examples/plots/README.md new file mode 100644 index 0000000..4bb3e3d --- /dev/null +++ b/examples/plots/README.md @@ -0,0 +1,57 @@ +# Plots + +Requirements: `pytorch_warmup` and `matplotlib`. + +## Effective Warmup Period + +

+ Warmup period
+ Effective warmup periods of RAdam and rule-of-thumb warmup schedules, as a function of 𝛽₂. +

+ +Run the Python script `effective_warmup_period.py` to show up the figure above: + +```shell +python effective_warmup_period.py +``` + +### Usage + +``` +usage: effective_warmup_period.py [-h] [--output {none,png,pdf}] + +Effective warmup period + +options: + -h, --help show this help message and exit + --output {none,png,pdf} + Output file type (default: none) +``` + +## Warmup Schedule + +

+ Warmup schedule
+ RAdam and rule-of-thumb warmup schedules over time for 𝛽₂ = 0.999. +

+ +Run the Python script `warmup_schedule.py` to show up the figure above: + +```shell +python warmup_schedule.py +``` + +### Usage + +``` +usage: warmup_schedule.py [-h] [--output {none,png,pdf}] + +Warmup schedule + +options: + -h, --help show this help message and exit + --output {none,png,pdf} + Output file type (default: none) +``` + +© 2024 Takenori Yamamoto \ No newline at end of file diff --git a/pytorch_warmup/base.py b/pytorch_warmup/base.py index fea9a39..ed0fe66 100644 --- a/pytorch_warmup/base.py +++ b/pytorch_warmup/base.py @@ -10,12 +10,21 @@ def _check_optimizer(optimizer): class BaseWarmup(object): - """Base class for all warmup schedules + """Base class for all warmup schedules. - Arguments: - optimizer (Optimizer): an instance of a subclass of Optimizer - warmup_params (list): warmup paramters - last_step (int): The index of last step. (Default: -1) + The learning rate :math:`\\alpha_{t}` is dampened by multiplying it by + the warmup factor :math:`\\omega_{t} \\in [0, 1]` at each iteration :math:`t`. + Thus, the modified learning rate + + .. math:: + \\hat \\alpha_{t} = \\alpha_{t} \\cdot \\omega_{t} + + is used by the optimizer. + + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_params (list): Warmup parameters. + last_step (int): The index of last step. Default: -1. """ def __init__(self, optimizer, warmup_params, last_step=-1): @@ -28,7 +37,7 @@ def __init__(self, optimizer, warmup_params, last_step=-1): def state_dict(self): """Returns the state of the warmup scheduler as a :class:`dict`. - It contains an entry for every variable in self.__dict__ which + It contains an entry for every variable in :attr:`self.__dict__` which is not the optimizer. """ return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} @@ -36,17 +45,20 @@ def state_dict(self): def load_state_dict(self, state_dict): """Loads the warmup scheduler's state. - Arguments: - state_dict (dict): warmup scheduler state. Should be an object returned + Args: + state_dict (dict): Warmup scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ self.__dict__.update(state_dict) def dampen(self, step=None): - """Dampen the learning rates. + """Dampens the learning rate. - Arguments: - step (int): The index of current step. (Default: None) + It is not recommended to explicitly call this method for PyTorch 1.4.0 or later. + Please use the :meth:`dampening` context manager that calls this method correctly. + + Args: + step (int): The index of current step. Default: ``None``. """ if step is None: step = self.last_step + 1 @@ -58,6 +70,32 @@ def dampen(self, step=None): @contextmanager def dampening(self): + """Dampens the learning rate after calling the :meth:`step` method of the learning + rate scheduler. + + The :meth:`step` method calls must be placed in a suite of the ``with`` statement having + the :meth:`dampening` context manager. + + Examples: + >>> # For no LR scheduler + >>> with warmup_scheduler.dampening(): + >>> pass + + >>> # For a single LR scheduler + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + >>> # To chain two LR schedulers + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler1.step() + >>> lr_scheduler2.step() + + >>> # To delay an LR scheduler + >>> iteration = warmup_scheduler.last_step + 1 + >>> with warmup_scheduler.dampening(): + >>> if iteration >= warmup_period: + >>> lr_scheduler.step() + """ for group, lr in zip(self.optimizer.param_groups, self.lrs): group['lr'] = lr yield @@ -65,6 +103,16 @@ def dampening(self): self.dampen() def warmup_factor(self, step, **params): + """Returns the warmup factor :math:`\\omega_{t}` at an iteration :math:`t`. + + :meth:`dampen` uses this method to get the warmup factor for each parameter group. + It is unnecessary to explicitly call this method. + + Args: + step (int): The index of current step. + params (dict): The warmup parameters. For details, refer to the arguments of + each subclass method. + """ raise NotImplementedError @@ -98,10 +146,32 @@ def get_warmup_params(warmup_period, group_count): class LinearWarmup(BaseWarmup): """Linear warmup schedule. - Arguments: - optimizer (Optimizer): an instance of a subclass of Optimizer - warmup_period (int or list): Warmup period - last_step (int): The index of last step. (Default: -1) + The linear warmup schedule uses the warmup factor + + .. math:: + \\omega_{t}^{\\rm linear, \\tau} = \\min \\{ 1, \\frac{1}{\\tau} \\cdot t \\} + + at each iteration :math:`t`, where :math:`\\tau` is the warmup period. + + Args: + optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the + warmup redundancy. + warmup_period (int or list[int]): The warmup period :math:`\\tau`. + last_step (int): The index of last step. Default: -1. + + Example: + >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) + >>> warmup_scheduler = LinearWarmup(optimizer, warmup_period=2000) + >>> for batch in dataloader: + >>> optimizer.zero_grad() + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + Note: + The warmup schedule must not be initialized before the initialization of the learning rate schedule. """ def __init__(self, optimizer, warmup_period, last_step=-1): @@ -111,16 +181,45 @@ def __init__(self, optimizer, warmup_period, last_step=-1): super(LinearWarmup, self).__init__(optimizer, warmup_params, last_step) def warmup_factor(self, step, warmup_period): + """Returns the warmup factor :math:`\\omega_{t}^{\\rm linear, \\tau}` at an iteration :math:`t`. + + Args: + step (int): The index of current step. + warmup_period (int): The warmup period :math:`\\tau`. + """ return min(1.0, (step+1) / warmup_period) class ExponentialWarmup(BaseWarmup): """Exponential warmup schedule. - Arguments: - optimizer (Optimizer): an instance of a subclass of Optimizer - warmup_period (int or list): Effective warmup period - last_step (int): The index of last step. (Default: -1) + The exponential warmup schedule uses the warmup factor + + .. math:: + \\omega_{t}^{\\rm expo, \\tau} = 1 - \\exp \\left( - \\frac{1}{\\tau} \\cdot t \\right) + + at each iteration :math:`t`, where the constant :math:`\\tau` is analogous to + a linear warmup period. + + Args: + optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the + warmup redundancy. + warmup_period (int or list[int]): The constant :math:`\\tau` analogous to a linear warmup period. + last_step (int): The index of last step. Default: -1. + + Example: + >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) + >>> warmup_scheduler = ExponentialWarmup(optimizer, warmup_period=1000) + >>> for batch in dataloader: + >>> optimizer.zero_grad() + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + Note: + The warmup schedule must not be initialized before the initialization of the learning rate schedule. """ def __init__(self, optimizer, warmup_period, last_step=-1): @@ -130,4 +229,10 @@ def __init__(self, optimizer, warmup_period, last_step=-1): super(ExponentialWarmup, self).__init__(optimizer, warmup_params, last_step) def warmup_factor(self, step, warmup_period): + """Returns the warmup factor :math:`\\omega_{t}^{\\rm expo, \\tau}` at an iteration :math:`t`. + + Args: + step (int): The index of current step. + warmup_period (int): The constant :math:`\\tau` analogous to a linear warmup period. + """ return 1.0 - math.exp(-(step+1) / warmup_period) diff --git a/pytorch_warmup/radam.py b/pytorch_warmup/radam.py index 6065e6f..54a64c7 100644 --- a/pytorch_warmup/radam.py +++ b/pytorch_warmup/radam.py @@ -3,16 +3,35 @@ def rho_inf_fn(beta2): + """Returns the constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. + + Args: + beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. + """ return 2.0 / (1 - beta2) - 1 def rho_fn(t, beta2, rho_inf): + """Returns the value of the function of the RAdam algorithm, :math:`\\rho_{t}`, + at an iteration :math:`t`. + + Args: + t (int): The iteration :math:`t`. + beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. + rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. + """ b2t = beta2 ** t rho_t = rho_inf - 2 * t * b2t / (1 - b2t) return rho_t def get_offset(beta2, rho_inf): + """Returns the minimal offset :math:`\\delta`. + + Args: + beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. + rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. + """ if not beta2 > 0.6: raise ValueError('beta2 ({}) must be greater than 0.6'.format(beta2)) offset = 1 @@ -29,9 +48,77 @@ class RAdamWarmup(BaseWarmup): `On the adequacy of untuned warmup for adaptive optimization `_. - Arguments: - optimizer (Optimizer): an Adam optimizer - last_step (int): The index of last step. (Default: -1) + The RAdam algorithm uses the warmup factor + + .. math:: + \\omega_{t}^{\\rm RAdam} = \\sqrt{ \\frac{ \\ + ( \\rho_{t} - 4 ) ( \\rho_{t} - 2 ) \\rho_{\\infty} }{ \\ + ( \\rho_{\\infty} - 4) (\\rho_{\\infty} - 2 ) \\rho_{t} } } + + at each iteration :math:`t` for :math:`\\rho_{t} > 4`, where + + .. math:: + \\rho_{\\infty} = \\frac{ 2 }{ 1 - \\beta_{2} } - 1 + + and + + .. math:: + \\rho_{t} = \\rho_{\\infty} - \\frac{ 2 t \\cdot \\beta_{2}^{t} }{ 1 - \\beta_{2}^{t} } + + where :math:`\\beta_{2}` is the second discount factor of Adam. In the RAdam warmup schedule, + the minimal offset :math:`\\delta` is chosen such that :math:`\\rho_{\\delta} > 4`, and then + :math:`\\omega_{t+\\delta-1}^{\\rm RAdam}` is employed as the warmup factor at each iteration :math:`t`. + For all practically relevant values of :math:`\\beta_{2}` (:math:`0.8 < \\beta_{2} \\le 1`), + :math:`\\delta \\le 5` as deduced from Fact 3.1 of the paper. + + Args: + optimizer (Optimizer): Adam optimizer or its variant: + :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`. + :class:`RAdam` is not suitable because of the warmup redundancy. This warmup + schedule makes no sense for :class:`Adamax` and, in principle, the AMSGrad variant of + :class:`Adam` and :class:`AdamW` as discussed in Note below. In practice, this warmup + schedule improves the performance of the AMSGrad variant like that of the vanilla Adam. + last_step (int): The index of last step. Default: -1. + + Note: + This warmup schedule employs the same warmup factor for all variants of Adam. However, + according to the RAdam theory, + :class:`Adamax` and the AMSGrad variant of :class:`Adam` and :class:`AdamW` should + have a different warmup factor because its :math:`\\psi(\\cdot)` function is different from one of the + vanilla Adam, where :math:`\\psi(\\cdot)` specifies how the adaptive learning rate at :math:`t` is + calculated. The RAdam theory derives the warmup factor :math:`\\omega_{t}` from + :math:`\\psi(\\cdot)`. For gradients :math:`\\left\\{ g_{i} \\right\\}` viewed as i.i.d. normal random + variables, + + .. centered:: + :math:`\\omega_{t} = \\sqrt{ C_{\\rm var} / {\\rm Var}\\left[ \\psi(g_{1}, \\dots, g_{t}) \\right] }` + + where + + .. centered:: + :math:`C_{\\rm var} = \\inf_{t} {\\rm Var}\\left[ \\psi(g_{1}, \\dots, g_{t}) \\right]`. + + (For details please refer to `On the Variance of the Adaptive Learning Rate and Beyond + `_.) + + The variance hypothesis of the RAdam theory has become questionable + since Ma and Yarats' paper pointed out that the adaptive learning rate may not be the best medium + of analysis for understanding the role of warmup in Adam. + + Example: + >>> optimizer = AdamW(...) + >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) + >>> warmup_scheduler = RAdamWarmup(optimizer) + >>> for batch in dataloader: + >>> optimizer.zero_grad() + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + Note: + The warmup schedule must not be initialized before the initialization of the learning rate schedule. """ def __init__(self, optimizer, last_step=-1): @@ -48,6 +135,14 @@ def __init__(self, optimizer, last_step=-1): super(RAdamWarmup, self).__init__(optimizer, warmup_params, last_step) def warmup_factor(self, step, beta2, rho_inf, offset): + """Returns the warmup factor :math:`\\omega_{t+\\delta-1}^{\\rm RAdam}` at an iteration :math:`t`. + + Args: + step (int): The index of current step. + beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. + rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. + offset (int): The minimal offset :math:`\\delta`. + """ rho = rho_fn(step+offset, beta2, rho_inf) numerator = (rho - 4) * (rho - 2) * rho_inf denominator = (rho_inf - 4) * (rho_inf - 2) * rho diff --git a/pytorch_warmup/untuned.py b/pytorch_warmup/untuned.py index 9ceaacb..76c4155 100644 --- a/pytorch_warmup/untuned.py +++ b/pytorch_warmup/untuned.py @@ -8,9 +8,56 @@ class UntunedLinearWarmup(LinearWarmup): `On the adequacy of untuned warmup for adaptive optimization `_. - Arguments: - optimizer (Optimizer): an Adam optimizer - last_step (int): The index of last step. (Default: -1) + The untuned linear warmup schedule uses the warmup factor + + .. math:: + \\omega_{t}^{\\rm linear, untuned} = \\min \\{ 1, \\frac{1 - \\beta_{2}}{2} \\cdot t \\} + + at each iteration :math:`t`, where :math:`\\beta_{2}` is the second discount factor of Adam. + In practice, :math:`\\omega_{t}^{\\rm linear, untuned}` is calculated as + :math:`\\omega_{t}^{\\rm linear, \\tau}` with :math:`\\tau = \\frac{2}{1 - \\beta_{2}}`. + + Note: + The effective warmup period is defined as + + .. centered:: + :math:`{\\cal T}(\\omega) = \\sum_{t = 1}^{\\infty} \\left( 1 - \\omega_{t} \\right)` + + for a warmup schedule :math:`\\omega = \\left\\{ \\omega_{t} \\right\\}_{t=1}^{\\infty}`. + The warmup period :math:`\\tau` is deduced from solving approximately the rough equivalence: + + .. centered:: + :math:`{\\cal T}(\\omega^{\\rm expo, untuned}) \\approx {\\cal T}(\\omega^{{\\rm linear}, + \\tau}) \\approx \\frac{\\tau}{2}`. + + Args: + optimizer (Optimizer): Adam optimizer or its variant: + :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`. + :class:`RAdam` is not suitable because of the warmup redundancy. This warmup + schedule makes no sense for :class:`Adamax` as discussed in Note below. + last_step (int): The index of last step. Default: -1. + + Note: + This warmup schedule employs the same warmup period :math:`\\tau` for all variants of Adam. However, + :class:`Adamax` should in principle need no linear warmup because it needs no exponential warmup. + For further details please refer to Note in the documentation of :class:`UntunedExponentialWarmup`. + In practice, a linear warmup may slightly improve AdaMax's performance because the initial update step + is the same as one of the Adam optimizer. + + Example: + >>> optimizer = AdamW(...) + >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) + >>> warmup_scheduler = UntunedLinearWarmup(optimizer) + >>> for batch in dataloader: + >>> optimizer.zero_grad() + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + Note: + The warmup schedule must not be initialized before the initialization of the learning rate schedule. """ def __init__(self, optimizer, last_step=-1): @@ -23,15 +70,71 @@ def warmup_period_fn(beta2): class UntunedExponentialWarmup(ExponentialWarmup): - """Untuned exponetial warmup schedule for Adam. + """Untuned exponential warmup schedule for Adam. This warmup scheme is described in `On the adequacy of untuned warmup for adaptive optimization `_. - Arguments: - optimizer (Optimizer): an Adam optimizer - last_step (int): The index of last step. (Default: -1) + The untuned exponential warmup schedule uses the warmup factor + + .. math:: + \\omega_{t}^{\\rm expo, untuned} = 1 - \\exp \\left( - (1 - \\beta_{2}) \\cdot t \\right) + + at each iteration :math:`t`, where :math:`\\beta_{2}` is the second discount factor of Adam. + In practice, :math:`\\omega_{t}^{\\rm expo, untuned}` is calculated as + :math:`\\omega_{t}^{\\rm expo, \\tau}` with :math:`\\tau = \\frac{1}{1 - \\beta_{2}}`. + + Note: + The constant :math:`\\tau` is derived from the intuition that + the warmup factor should be roughly equivalent to Adam's second moment bias correction term, + :math:`1 - \\beta_{2}^{t}`. + + Note: + The effective warmup period is defined as + + .. centered:: + :math:`{\\cal T}(\\omega) = \\sum_{t = 1}^{\\infty} \\left( 1 - \\omega_{t} \\right)` + + for a warmup schedule :math:`\\omega = \\left\\{ \\omega_{t} \\right\\}_{t=1}^{\\infty}`. + The constant :math:`\\tau` of the untuned exponential warmup schedule is roughly equivalent to + its effective warmup period: + + .. centered:: + :math:`{\\cal T}(\\omega^{\\rm expo, untuned}) = 1 / \\left( \\exp( 1 - \\beta_{2}) - 1 \\right) \\approx \\tau` + + for :math:`\\beta_{2}` near 1. The rough equivalence is also achieved for an exponential warmup schedule + if its :math:`\\tau` is large enough, for example, :math:`\\tau \\ge 1`. + + Args: + optimizer (Optimizer): Adam optimizer or its variant: + :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`. + :class:`RAdam` is not suitable because of the warmup redundancy. This warmup + schedule makes no sense for :class:`Adamax` as discussed in Note below. + last_step (int): The index of last step. Default: -1. + + Note: + This warmup schedule employs the same constant :math:`\\tau` for all variants of Adam. However, + :class:`Adamax` should in principle need no warmup because :class:`Adamax` is derived by employing + a :math:`L^{p}` norm update rule and letting :math:`p \\rightarrow \\infty`, and the second moment bias + correction term is :math:`1-\\beta_{2}^{pt}`, to which the warmup factor must be roughly equivalent + in this warmup schedule derivation. In practice, an exponential warmup may slightly improve AdaMax's + performance because the initial update step is the same as one of the Adam optimizer. + + Example: + >>> optimizer = AdamW(...) + >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) + >>> warmup_scheduler = UntunedExponentialWarmup(optimizer) + >>> for batch in dataloader: + >>> optimizer.zero_grad() + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + Note: + The warmup schedule must not be initialized before the initialization of the learning rate schedule. """ def __init__(self, optimizer, last_step=-1): From c2b6bcdfb06d4ffb309f145ffe819240ea14dfc5 Mon Sep 17 00:00:00 2001 From: Tony-Y <11532812+Tony-Y@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:25:34 +0900 Subject: [PATCH 2/3] Use HTML symbols --- README.md | 4 ++-- examples/emnist/README.md | 2 +- examples/plots/README.md | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index fb43b61..fef7edf 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ pip install -U pytorch_warmup * [EMNIST](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/emnist) - A sample script to train a CNN model on the EMNIST dataset using the Adam algorithm with a warmup. * [Plots](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/plots) - - A script to plot effective warmup periods as a function of 𝛽₂, and warmup schedules over time. + A script to plot effective warmup periods as a function of β₂, and warmup schedules over time. ## Usage @@ -221,4 +221,4 @@ warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) MIT License -© 2019-2024 Takenori Yamamoto +© 2019-2024 Takenori Yamamoto diff --git a/examples/emnist/README.md b/examples/emnist/README.md index 4c0490c..04d6050 100644 --- a/examples/emnist/README.md +++ b/examples/emnist/README.md @@ -90,4 +90,4 @@ options: --save-model For Saving the current Model ``` -© 2024 Takenori Yamamoto \ No newline at end of file +© 2024 Takenori Yamamoto \ No newline at end of file diff --git a/examples/plots/README.md b/examples/plots/README.md index 4bb3e3d..536972d 100644 --- a/examples/plots/README.md +++ b/examples/plots/README.md @@ -6,7 +6,7 @@ Requirements: `pytorch_warmup` and `matplotlib`.

Warmup period
- Effective warmup periods of RAdam and rule-of-thumb warmup schedules, as a function of 𝛽₂. + Effective warmup periods of RAdam and rule-of-thumb warmup schedules, as a function of β₂.

Run the Python script `effective_warmup_period.py` to show up the figure above: @@ -32,7 +32,7 @@ options:

Warmup schedule
- RAdam and rule-of-thumb warmup schedules over time for 𝛽₂ = 0.999. + RAdam and rule-of-thumb warmup schedules over time for β₂ = 0.999.

Run the Python script `warmup_schedule.py` to show up the figure above: @@ -54,4 +54,4 @@ options: Output file type (default: none) ``` -© 2024 Takenori Yamamoto \ No newline at end of file +© 2024 Takenori Yamamoto \ No newline at end of file From e8625527bbb5df2bc0d109b0801fc011d045bc58 Mon Sep 17 00:00:00 2001 From: Tony-Y <11532812+Tony-Y@users.noreply.github.com> Date: Thu, 10 Oct 2024 20:32:54 +0900 Subject: [PATCH 3/3] Remove test_suite from setuptools options --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index f9db849..615def1 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,5 @@ "License :: OSI Approved :: MIT License", ], python_requires='>=3.7', - test_suite='test', install_requires=['torch>=1.1'] )