diff --git a/.github/workflows/sphinx-gh-pages.yml b/.github/workflows/sphinx-gh-pages.yml index a018492..43da29c 100644 --- a/.github/workflows/sphinx-gh-pages.yml +++ b/.github/workflows/sphinx-gh-pages.yml @@ -35,7 +35,7 @@ jobs: run: | python -m pip install --upgrade pip pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install sphinx sphinxcontrib-katex sphinx-rtd-theme + pip install sphinx sphinxcontrib-katex sphinx-copybutton sphinx-rtd-theme - name: Sphinx Build run: | cd docs/ diff --git a/LICENSE b/LICENSE index dc0eee3..19c1d48 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2019 Takenori Yamamoto +Copyright (c) 2019-2024 Takenori Yamamoto Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/MANIFEST.in b/MANIFEST.in index 1cde148..5646fca 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,7 @@ include test/__init__.py +include examples/plots/README.md include examples/plots/*.py include examples/plots/figs/*.png +include examples/emnist/README.md include examples/emnist/*.py include examples/emnist/figs/*.png diff --git a/README.md b/README.md index 3697e77..fef7edf 100644 --- a/README.md +++ b/README.md @@ -4,27 +4,36 @@ This library contains PyTorch implementations of the warmup schedules described
-![Python package](https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg)
+[![Python package](https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg)](https://github.com/Tony-Y/pytorch_warmup/)
 [![PyPI version shields.io](https://img.shields.io/pypi/v/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/)
-[![PyPI license](https://img.shields.io/pypi/l/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/)
-[![PyPI pyversions](https://img.shields.io/pypi/pyversions/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/)
+[![PyPI license](https://img.shields.io/pypi/l/pytorch-warmup.svg)](https://github.com/Tony-Y/pytorch_warmup/blob/master/LICENSE)
+[![Python versions](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org)
 ## Installation
-Make sure you have Python 3.7+ and PyTorch 1.1+. Then, run the following command in the project directory:
+Make sure you have Python 3.7+ and PyTorch 1.1+ or 2.x. Then, run the following command in the project directory:
-```
+```shell
 python -m pip install .
 ```
 or install the latest version from the Python Package Index:
-```
+```shell
 pip install -U pytorch_warmup
 ```
+## Examples
+
+* [EMNIST](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/emnist) -
+  A sample script to train a CNN model on the EMNIST dataset using the Adam algorithm with a warmup.
+* [Plots](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/plots) -
+  A script to plot effective warmup periods as a function of β₂, and warmup schedules over time.
+
 ## Usage
+The [Documentation](https://tony-y.github.io/pytorch_warmup/) provides more detailed information on this library, beyond what is covered below.
+
 ### Sample Codes
 The scheduled learning rate is dampened by the multiplication of the warmup factor:
@@ -34,16 +43,20 @@ The scheduled learning rate is dampened by the multiplication of the warmup fact
 #### Approach 1
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach1_chaining.ipynb)
-When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used as follows:
+When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used
+together with `Adam` or its variant (`AdamW`, `NAdam`, etc.) as follows:
 ```python
 import torch
 import pytorch_warmup as warmup
 optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
+    # This sample code uses the AdamW optimizer.
 num_steps = len(dataloader) * num_epochs
 lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
+    # The LR schedule initialization resets the initial LR of the optimizer.
 warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
+    # The warmup schedule initialization dampens the initial LR of the optimizer.
 for epoch in range(1,num_epochs+1):
     for batch in dataloader:
         optimizer.zero_grad()
@@ -53,9 +66,9 @@ for epoch in range(1,num_epochs+1):
         with warmup_scheduler.dampening():
             lr_scheduler.step()
 ```
-Note that the warmup schedule must not be initialized before the learning rate schedule.
+Note that the warmup schedule must not be initialized before the initialization of the learning rate schedule.
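To see the dampening that this hunk documents, one can print the optimizer's learning rate for the first few steps: the ramp produced by the untuned linear warmup is then visible directly. The snippet below is a self-contained sketch, not part of the patch; the toy linear model, the random batch, and the 100-step loop are made up purely for illustration.

```python
import torch
import pytorch_warmup as warmup

model = torch.nn.Linear(10, 1)  # toy stand-in for a real network
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001,
                              betas=(0.9, 0.999), weight_decay=0.01)

num_steps = 100
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
# Created after the LR scheduler, as required by the note above.
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

for step in range(num_steps):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    with warmup_scheduler.dampening():
        lr_scheduler.step()
    if step < 3:
        # Dampened LR that will be used for the next update:
        # the scheduled LR multiplied by the warmup factor.
        print(step, optimizer.param_groups[0]['lr'])
```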
-If you want to use the learning rate schedule "chaining" which is supported for PyTorch 1.4.0 or above, you may simply give a code of learning rate schedulers as a suite of the `with` statement:
+If you want to use the learning rate schedule *chaining*, which is supported for PyTorch 1.4 or above, you may simply write the code of the learning rate schedulers as a suite of the `with` statement:
 ```python
 lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
 lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
@@ -163,7 +176,7 @@ warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=1000)
 #### Untuned Warmup
-The warmup period is given by a function of Adam's `beta2` parameter for `UntunedLinearWarmup` and `UntunedExponentialWarmup`.
+The warmup period is determined by a function of Adam's `beta2` parameter for `UntunedLinearWarmup` and `UntunedExponentialWarmup`.
 ##### Linear
@@ -183,7 +196,9 @@ warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer)
 #### RAdam Warmup
-The warmup factor depends on Adam's `beta2` parameter for `RAdamWarmup`. Please see the original paper for the details.
+The warmup factor depends on Adam's `beta2` parameter for `RAdamWarmup`. For details, please refer to the
+[Documentation](https://tony-y.github.io/pytorch_warmup/radam_warmup.html) or
+"[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)."
 ```python
 warmup_scheduler = warmup.RAdamWarmup(optimizer)
@@ -191,7 +206,7 @@ warmup_scheduler = warmup.RAdamWarmup(optimizer)
 ### Apex's Adam
-The Apex library provides an Adam optimizer tuned for CUDA devices, [FusedAdam](https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedAdam). The FusedAdam optimizer can be used with the warmup schedulers. For example:
+The Apex library provides an Adam optimizer tuned for CUDA devices, [FusedAdam](https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedAdam). The FusedAdam optimizer can be used together with any one of the warmup schedules above. For example:
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_FusedAdam.ipynb)
@@ -206,4 +221,4 @@ warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
 MIT License
-Copyright (c) 2019 Takenori Yamamoto
+© 2019-2024 Takenori Yamamoto
diff --git a/docs/conf.py b/docs/conf.py
index f6e8abd..fbf139e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -21,7 +21,7 @@
 # -- Project information -----------------------------------------------------
 project = 'PyTorch Warmup'
-copyright = '2019, Takenori Yamamoto'
+copyright = '2019-2024, Takenori Yamamoto'
 author = 'Takenori Yamamoto'
@@ -40,6 +40,7 @@
     'sphinx.ext.napoleon',
     'sphinx.ext.viewcode',
     'sphinxcontrib.katex',
+    'sphinx_copybutton',
 ]
 # Add any paths that contain templates here, relative to this directory.
@@ -61,4 +62,8 @@
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-#html_static_path = ['_static']
+# html_static_path = ['_static']
+
+# Copybutton settings
+copybutton_prompt_text = r">>> |\.\.\. |\$ "
|\$ " +copybutton_prompt_is_regexp = True diff --git a/docs/index.rst b/docs/index.rst index dc6ccc6..862f56a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,79 @@ This library contains PyTorch implementations of the warmup schedules described `On the adequacy of untuned warmup for adaptive optimization+ + Test accuracy over time for each warmup schedule. +
+
+Learning rate over time for each warmup schedule.
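The schedules compared in these figures are the ones selectable via the `--warmup` option of `main.py` (documented below). The sketch that follows shows one way such a choice could be mapped onto the library's classes; the helper name `build_warmup` and the treatment of `none` are illustrative assumptions, not necessarily how the actual script is organized.

```python
import pytorch_warmup as warmup

def build_warmup(name, optimizer):
    """Hypothetical helper mapping a --warmup choice onto a warmup scheduler."""
    if name == 'linear':
        return warmup.UntunedLinearWarmup(optimizer)
    if name == 'exponential':
        return warmup.UntunedExponentialWarmup(optimizer)
    if name == 'radam':
        return warmup.RAdamWarmup(optimizer)
    if name == 'none':
        # A one-step linear warmup acts as "no warmup": its factor is 1 from the first step.
        return warmup.LinearWarmup(optimizer, warmup_period=1)
    raise ValueError(f'unknown warmup schedule: {name}')
```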
+
+## Download EMNIST Dataset
+
+Run the Python script `download.py` to download the EMNIST dataset:
+
+```shell
+python download.py
+```
+
+This script shows download progress:
+
+```
+Downloading zip archive
+Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to .data/EMNIST/raw/gzip.zip
+100.0%
+```
+
+## Train A CNN Model
+
+Run the Python script `main.py` to train a CNN model on the EMNIST dataset using the Adam algorithm.
+
+### Untuned Linear Warmup
+
+Train a CNN model with the *Untuned Linear Warmup* schedule:
+
+```
+python main.py --warmup linear
+```
+
+### Untuned Exponential Warmup
+
+Train a CNN model with the *Untuned Exponential Warmup* schedule:
+
+```
+python main.py --warmup exponential
+```
+
+### RAdam Warmup
+
+Train a CNN model with the *RAdam Warmup* schedule:
+
+```
+python main.py --warmup radam
+```
+
+### No Warmup
+
+Train a CNN model without warmup:
+
+```
+python main.py --warmup none
+```
+
+### Usage
+
+```
+usage: main.py [-h] [--batch-size N] [--test-batch-size N] [--epochs N] [--lr LR]
+               [--lr-min LM] [--wd WD] [--beta2 B2] [--no-cuda] [--seed S]
+               [--log-interval N] [--warmup {linear,exponential,radam,none}] [--save-model]
+
+PyTorch EMNIST Example
+
+options:
+  -h, --help            show this help message and exit
+  --batch-size N        input batch size for training (default: 64)
+  --test-batch-size N   input batch size for testing (default: 1000)
+  --epochs N            number of epochs to train (default: 10)
+  --lr LR               base learning rate (default: 0.01)
+  --lr-min LM           minimum learning rate (default: 1e-5)
+  --wd WD               weight decay (default: 0.01)
+  --beta2 B2            Adam's beta2 parameter (default: 0.999)
+  --no-cuda             disables CUDA training
+  --seed S              random seed (default: 1)
+  --log-interval N      how many batches to wait before logging training status
+  --warmup {linear,exponential,radam,none}
+                        warmup schedule
+  --save-model          For Saving the current Model
+```
+
+© 2024 Takenori Yamamoto
\ No newline at end of file
diff --git a/examples/plots/README.md b/examples/plots/README.md
new file mode 100644
index 0000000..536972d
--- /dev/null
+++ b/examples/plots/README.md
@@ -0,0 +1,57 @@
+# Plots
+
+Requirements: `pytorch_warmup` and `matplotlib`.
+
+## Effective Warmup Period
+
+Effective warmup periods of RAdam and rule-of-thumb warmup schedules, as a function of β₂.
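For the rule-of-thumb curves, the periods can be computed directly from β₂: the untuned linear warmup uses roughly 2/(1−β₂) steps and the untuned exponential warmup uses the constant 1/(1−β₂). The sketch below evaluates only those two quantities (the RAdam curve requires the plotting script itself); the helper names are illustrative.

```python
# Rule-of-thumb warmup periods as a function of Adam's beta2 (sketch only).
def untuned_linear_period(beta2):
    return int(2.0 / (1.0 - beta2))

def untuned_exponential_constant(beta2):
    return int(1.0 / (1.0 - beta2))

for beta2 in (0.99, 0.995, 0.999):
    print(beta2, untuned_linear_period(beta2), untuned_exponential_constant(beta2))
# beta2 = 0.999 gives a linear period of 2000 steps and an exponential constant of 1000.
```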
+
+Run the Python script `effective_warmup_period.py` to show the figure above:
+
+```shell
+python effective_warmup_period.py
+```
+
+### Usage
+
+```
+usage: effective_warmup_period.py [-h] [--output {none,png,pdf}]
+
+Effective warmup period
+
+options:
+  -h, --help            show this help message and exit
+  --output {none,png,pdf}
+                        Output file type (default: none)
+```
+
+## Warmup Schedule
+
+RAdam and rule-of-thumb warmup schedules over time for β₂ = 0.999.
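The linear and exponential curves in this figure follow the closed forms documented in `pytorch_warmup/base.py` later in this patch: ω_t = min(1, (t+1)/τ) and ω_t = 1 − exp(−(t+1)/τ). The sketch below evaluates them assuming the rule-of-thumb values τ = 2000 (linear) and τ = 1000 (exponential) for β₂ = 0.999; the RAdam curve again needs the script itself.

```python
import math

def linear_factor(step, warmup_period=2000):       # tau for beta2 = 0.999
    return min(1.0, (step + 1) / warmup_period)

def exponential_factor(step, warmup_period=1000):  # tau for beta2 = 0.999
    return 1.0 - math.exp(-(step + 1) / warmup_period)

for step in (0, 499, 999, 1999, 3999):
    print(step + 1, round(linear_factor(step), 3), round(exponential_factor(step), 3))
```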
+
+Run the Python script `warmup_schedule.py` to show the figure above:
+
+```shell
+python warmup_schedule.py
+```
+
+### Usage
+
+```
+usage: warmup_schedule.py [-h] [--output {none,png,pdf}]
+
+Warmup schedule
+
+options:
+  -h, --help            show this help message and exit
+  --output {none,png,pdf}
+                        Output file type (default: none)
+```
+
+© 2024 Takenori Yamamoto
\ No newline at end of file
diff --git a/pytorch_warmup/base.py b/pytorch_warmup/base.py
index fea9a39..ed0fe66 100644
--- a/pytorch_warmup/base.py
+++ b/pytorch_warmup/base.py
@@ -10,12 +10,21 @@ def _check_optimizer(optimizer):
 class BaseWarmup(object):
-    """Base class for all warmup schedules
+    """Base class for all warmup schedules.
-    Arguments:
-        optimizer (Optimizer): an instance of a subclass of Optimizer
-        warmup_params (list): warmup paramters
-        last_step (int): The index of last step. (Default: -1)
+    The learning rate :math:`\\alpha_{t}` is dampened by multiplying it by
+    the warmup factor :math:`\\omega_{t} \\in [0, 1]` at each iteration :math:`t`.
+    Thus, the modified learning rate
+
+    .. math::
+        \\hat \\alpha_{t} = \\alpha_{t} \\cdot \\omega_{t}
+
+    is used by the optimizer.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        warmup_params (list): Warmup parameters.
+        last_step (int): The index of last step. Default: -1.
     """
 
     def __init__(self, optimizer, warmup_params, last_step=-1):
@@ -28,7 +37,7 @@ def __init__(self, optimizer, warmup_params, last_step=-1):
     def state_dict(self):
         """Returns the state of the warmup scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in :attr:`self.__dict__` which
         is not the optimizer.
         """
         return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
@@ -36,17 +45,20 @@ def state_dict(self):
     def load_state_dict(self, state_dict):
         """Loads the warmup scheduler's state.
 
-        Arguments:
-            state_dict (dict): warmup scheduler state. Should be an object returned
+        Args:
+            state_dict (dict): Warmup scheduler state. Should be an object returned
                 from a call to :meth:`state_dict`.
         """
         self.__dict__.update(state_dict)
 
     def dampen(self, step=None):
-        """Dampen the learning rates.
+        """Dampens the learning rate.
 
-        Arguments:
-            step (int): The index of current step. (Default: None)
+        It is not recommended to explicitly call this method for PyTorch 1.4.0 or later.
+        Please use the :meth:`dampening` context manager that calls this method correctly.
+
+        Args:
+            step (int): The index of current step. Default: ``None``.
         """
         if step is None:
             step = self.last_step + 1
@@ -58,6 +70,32 @@ def dampen(self, step=None):
 
     @contextmanager
     def dampening(self):
+        """Dampens the learning rate after calling the :meth:`step` method of the learning
+        rate scheduler.
+
+        The :meth:`step` method calls must be placed in a suite of the ``with`` statement having
+        the :meth:`dampening` context manager.
+ + Examples: + >>> # For no LR scheduler + >>> with warmup_scheduler.dampening(): + >>> pass + + >>> # For a single LR scheduler + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + >>> # To chain two LR schedulers + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler1.step() + >>> lr_scheduler2.step() + + >>> # To delay an LR scheduler + >>> iteration = warmup_scheduler.last_step + 1 + >>> with warmup_scheduler.dampening(): + >>> if iteration >= warmup_period: + >>> lr_scheduler.step() + """ for group, lr in zip(self.optimizer.param_groups, self.lrs): group['lr'] = lr yield @@ -65,6 +103,16 @@ def dampening(self): self.dampen() def warmup_factor(self, step, **params): + """Returns the warmup factor :math:`\\omega_{t}` at an iteration :math:`t`. + + :meth:`dampen` uses this method to get the warmup factor for each parameter group. + It is unnecessary to explicitly call this method. + + Args: + step (int): The index of current step. + params (dict): The warmup parameters. For details, refer to the arguments of + each subclass method. + """ raise NotImplementedError @@ -98,10 +146,32 @@ def get_warmup_params(warmup_period, group_count): class LinearWarmup(BaseWarmup): """Linear warmup schedule. - Arguments: - optimizer (Optimizer): an instance of a subclass of Optimizer - warmup_period (int or list): Warmup period - last_step (int): The index of last step. (Default: -1) + The linear warmup schedule uses the warmup factor + + .. math:: + \\omega_{t}^{\\rm linear, \\tau} = \\min \\{ 1, \\frac{1}{\\tau} \\cdot t \\} + + at each iteration :math:`t`, where :math:`\\tau` is the warmup period. + + Args: + optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the + warmup redundancy. + warmup_period (int or list[int]): The warmup period :math:`\\tau`. + last_step (int): The index of last step. Default: -1. + + Example: + >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) + >>> warmup_scheduler = LinearWarmup(optimizer, warmup_period=2000) + >>> for batch in dataloader: + >>> optimizer.zero_grad() + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + Note: + The warmup schedule must not be initialized before the initialization of the learning rate schedule. """ def __init__(self, optimizer, warmup_period, last_step=-1): @@ -111,16 +181,45 @@ def __init__(self, optimizer, warmup_period, last_step=-1): super(LinearWarmup, self).__init__(optimizer, warmup_params, last_step) def warmup_factor(self, step, warmup_period): + """Returns the warmup factor :math:`\\omega_{t}^{\\rm linear, \\tau}` at an iteration :math:`t`. + + Args: + step (int): The index of current step. + warmup_period (int): The warmup period :math:`\\tau`. + """ return min(1.0, (step+1) / warmup_period) class ExponentialWarmup(BaseWarmup): """Exponential warmup schedule. - Arguments: - optimizer (Optimizer): an instance of a subclass of Optimizer - warmup_period (int or list): Effective warmup period - last_step (int): The index of last step. (Default: -1) + The exponential warmup schedule uses the warmup factor + + .. math:: + \\omega_{t}^{\\rm expo, \\tau} = 1 - \\exp \\left( - \\frac{1}{\\tau} \\cdot t \\right) + + at each iteration :math:`t`, where the constant :math:`\\tau` is analogous to + a linear warmup period. + + Args: + optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the + warmup redundancy. 
+ warmup_period (int or list[int]): The constant :math:`\\tau` analogous to a linear warmup period. + last_step (int): The index of last step. Default: -1. + + Example: + >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) + >>> warmup_scheduler = ExponentialWarmup(optimizer, warmup_period=1000) + >>> for batch in dataloader: + >>> optimizer.zero_grad() + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> with warmup_scheduler.dampening(): + >>> lr_scheduler.step() + + Note: + The warmup schedule must not be initialized before the initialization of the learning rate schedule. """ def __init__(self, optimizer, warmup_period, last_step=-1): @@ -130,4 +229,10 @@ def __init__(self, optimizer, warmup_period, last_step=-1): super(ExponentialWarmup, self).__init__(optimizer, warmup_params, last_step) def warmup_factor(self, step, warmup_period): + """Returns the warmup factor :math:`\\omega_{t}^{\\rm expo, \\tau}` at an iteration :math:`t`. + + Args: + step (int): The index of current step. + warmup_period (int): The constant :math:`\\tau` analogous to a linear warmup period. + """ return 1.0 - math.exp(-(step+1) / warmup_period) diff --git a/pytorch_warmup/radam.py b/pytorch_warmup/radam.py index 6065e6f..54a64c7 100644 --- a/pytorch_warmup/radam.py +++ b/pytorch_warmup/radam.py @@ -3,16 +3,35 @@ def rho_inf_fn(beta2): + """Returns the constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. + + Args: + beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. + """ return 2.0 / (1 - beta2) - 1 def rho_fn(t, beta2, rho_inf): + """Returns the value of the function of the RAdam algorithm, :math:`\\rho_{t}`, + at an iteration :math:`t`. + + Args: + t (int): The iteration :math:`t`. + beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. + rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. + """ b2t = beta2 ** t rho_t = rho_inf - 2 * t * b2t / (1 - b2t) return rho_t def get_offset(beta2, rho_inf): + """Returns the minimal offset :math:`\\delta`. + + Args: + beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. + rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. + """ if not beta2 > 0.6: raise ValueError('beta2 ({}) must be greater than 0.6'.format(beta2)) offset = 1 @@ -29,9 +48,77 @@ class RAdamWarmup(BaseWarmup): `On the adequacy of untuned warmup for adaptive optimization