deepspeed docs
MikhailKardash authored and azhou-determined committed Oct 31, 2024
1 parent 983a8ab commit 66238ed
Showing 9 changed files with 350 additions and 61 deletions.
235 changes: 235 additions & 0 deletions docs/model-dev-guide/api-guides/apis-howto/deepspeed/deepspeed.rst
@@ -365,6 +365,241 @@ profiling batches 3 and 4.
rendering times for TensorBoard and memory issues. For long-running experiments, it is
recommended to configure a profiling schedule.

*******************
DeepSpeed Trainer
*******************

With the DeepSpeed Trainer API, you can implement and iterate on model training code locally before
running it on a cluster. When you are satisfied with your model code, you can configure and submit
it to the cluster.

The DeepSpeed Trainer API lets you do the following:

- Work locally, iterating on your model code.
- Debug models in your favorite debug environment (e.g., directly on your machine, in an IDE, or in
  a Jupyter notebook).
- Run training scripts without needing to use an experiment configuration file.
- Load previously saved checkpoints directly into your model.

Initializing the Trainer
========================

After defining your :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` subclass, initialize the
trial and the trainer. :meth:`~determined.pytorch.deepspeed.init` returns a
:class:`~determined.pytorch.deepspeed.DeepSpeedTrialContext`, which you use to instantiate the
trial. Then initialize :class:`~determined.pytorch.deepspeed.Trainer` with the trial and the
context.

.. code:: python

   import logging

   import determined as det
   from determined.pytorch import deepspeed as det_ds


   def main():
       with det_ds.init() as train_context:
           # MyTrial is your DeepSpeedTrial subclass.
           trial = MyTrial(train_context)
           trainer = det_ds.Trainer(trial, train_context)


   if __name__ == "__main__":
       # Configure logging
       logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT)
       main()

Training is configured with a call to :meth:`~determined.pytorch.deepspeed.Trainer.fit` with
training loop arguments, such as checkpointing periods, validation periods, and checkpointing
policy.

.. code:: diff

     import logging

     import determined as det
     from determined import pytorch
     from determined.pytorch import deepspeed as det_ds


     def main():
         with det_ds.init() as train_context:
             trial = MyTrial(train_context)
             trainer = det_ds.Trainer(trial, train_context)
   +         trainer.fit(
   +             max_length=pytorch.Epoch(10),
   +             checkpoint_period=pytorch.Batch(100),
   +             validation_period=pytorch.Batch(100),
   +             checkpoint_policy="all",
   +         )


     if __name__ == "__main__":
         # Configure logging
         logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT)
         main()

Run Your Training Script Locally
================================

Run training scripts locally without submitting to a cluster or defining an experiment configuration
file.

.. code:: python

   import logging

   import determined as det
   from determined import pytorch
   from determined.pytorch import deepspeed as det_ds


   def main():
       with det_ds.init() as train_context:
           trial = MyTrial(train_context)
           trainer = det_ds.Trainer(trial, train_context)
           trainer.fit(
               max_length=pytorch.Epoch(10),
               checkpoint_period=pytorch.Batch(100),
               validation_period=pytorch.Batch(100),
               checkpoint_policy="all",
           )


   if __name__ == "__main__":
       # Configure logging
       logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT)
       main()

You can run this Python script directly (``python3 train.py``) or in a Jupyter notebook. This code
trains for ten epochs, checkpointing and validating every 100 batches.
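
As a rough illustration of that schedule, the following sketch counts how often a 100-batch period
fires over ten epochs. The dataset size and batch size are hypothetical, ``periodic_steps`` is an
illustrative helper rather than part of Determined, and the sketch assumes that one epoch is one
full pass over the training data loader.

```python
import math
from typing import List


def periodic_steps(total_batches: int, period: int) -> List[int]:
    """Return the 1-based batch counts at which a periodic action fires."""
    return list(range(period, total_batches + 1, period))


# Hypothetical numbers chosen only for illustration.
samples_per_epoch = 60_000
global_batch_size = 64
epochs = 10

# Assumption: the last, partially filled batch of each epoch still counts.
batches_per_epoch = math.ceil(samples_per_epoch / global_batch_size)
total_batches = batches_per_epoch * epochs
steps = periodic_steps(total_batches, period=100)

print(batches_per_epoch, total_batches, len(steps))  # 938 9380 93
```

With these numbers, the periodic checkpoint and validation would each fire 93 times. Any extra
checkpoint or validation at the very end of training is not modeled here.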

Local Distributed Training
==========================

Local training can utilize multiple GPUs on a single node with a few modifications to the above
code.

.. code:: diff

   + import os

   + import deepspeed
   + from determined import core


     def main():
   +     # Initialize distributed backend before det_ds.init()
   +     deepspeed.init_distributed()
   +     # Set flag used by internal PyTorch training loop
   +     os.environ["DET_MANUAL_INIT_DISTRIBUTED"] = "true"
   +     # Initialize DistributedContext
         with det_ds.init(
   +         distributed=core.DistributedContext.from_deepspeed()
         ) as train_context:
             trial = MyTrial(train_context)
             trainer = det_ds.Trainer(trial, train_context)
             trainer.fit(
                 max_length=pytorch.Epoch(10),
                 checkpoint_period=pytorch.Batch(100),
                 validation_period=pytorch.Batch(100),
                 checkpoint_policy="all"
             )

You can invoke this code directly with your distributed backend's launcher: ``deepspeed --num_gpus=4
trainer.py --deepspeed --deepspeed_config ds_config.json``.
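
If you launch from a wrapper script, the launcher command above can be assembled programmatically.
The following is a minimal sketch: ``build_launch_command`` is a hypothetical helper, not part of
DeepSpeed or Determined, and it only uses the flags that already appear in this commit
(``--num_gpus``, ``--num_nodes``, ``--deepspeed``, ``--deepspeed_config``).

```python
import shlex
from typing import List


def build_launch_command(
    script: str,
    num_gpus: int,
    num_nodes: int = 1,
    ds_config: str = "ds_config.json",
) -> List[str]:
    """Build the argument list for the ``deepspeed`` launcher (illustrative only)."""
    if num_gpus < 1 or num_nodes < 1:
        raise ValueError("num_gpus and num_nodes must both be at least 1")
    cmd = ["deepspeed"]
    # Only pass --num_nodes when more than one node is requested.
    if num_nodes > 1:
        cmd.append(f"--num_nodes={num_nodes}")
    cmd.append(f"--num_gpus={num_gpus}")
    cmd += [script, "--deepspeed", "--deepspeed_config", ds_config]
    return cmd


command = build_launch_command("trainer.py", num_gpus=4)
print(shlex.join(command))
```

The resulting list can be passed to ``subprocess.run(command, check=True)``. Keeping the arguments
as a list avoids shell-quoting problems.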

Test Mode
=========

:meth:`~determined.pytorch.deepspeed.Trainer.fit` accepts a ``test_mode`` parameter. If it is
true, the trainer trains and validates your code for only one batch, checkpoints, and then exits.
This is helpful for debugging code or writing automated tests around your model code.

.. code:: diff

     trainer.fit(
         max_length=pytorch.Epoch(10),
         checkpoint_period=pytorch.Batch(100),
         validation_period=pytorch.Batch(100),
   +     test_mode=True
     )

Prepare Your Training Code for Deploying to a Determined Cluster
================================================================

Once you are satisfied with the results of training the model locally, you can submit the code to a
cluster. The example below supports distributed training both locally and on a cluster without any
code changes.

The following workflow supports frequent iteration between local debugging and cluster deployment:

.. code:: diff

     def main():
   +     local = det.get_cluster_info() is None
   +     if local:
   +         # Local: configure local distributed training.
   +         deepspeed.init_distributed()
   +         # Set flag used by internal PyTorch training loop
   +         os.environ["DET_MANUAL_INIT_DISTRIBUTED"] = "true"
   +         distributed_context = core.DistributedContext.from_deepspeed()
   +         latest_checkpoint = None
   +     else:
   +         # On-cluster: Determined will automatically detect distributed context.
   +         distributed_context = None
   +         # On-cluster: configure the latest checkpoint for pause/resume training functionality.
   +         latest_checkpoint = det.get_cluster_info().latest_checkpoint
   +     with det_ds.init(
   +         distributed=distributed_context
         ) as train_context:
             trial = DCGANTrial(train_context)
             trainer = det_ds.Trainer(trial, train_context)
             trainer.fit(
                 max_length=pytorch.Epoch(11),
                 checkpoint_period=pytorch.Batch(100),
                 validation_period=pytorch.Batch(100),
   +             latest_checkpoint=latest_checkpoint,
             )

To run the Trainer API solely on-cluster, the code is much simpler:

.. code:: python

   def main():
       with det_ds.init() as train_context:
           trial_inst = gan_model.DCGANTrial(train_context)
           trainer = det_ds.Trainer(trial_inst, train_context)
           trainer.fit(
               max_length=pytorch.Epoch(11),
               checkpoint_period=pytorch.Batch(100),
               validation_period=pytorch.Batch(100),
               latest_checkpoint=det.get_cluster_info().latest_checkpoint,
           )

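
The local-versus-cluster branching used in this section can be sketched in isolation, without
Determined installed. In the minimal sketch below, ``FakeClusterInfo`` is a hypothetical stand-in
for the object returned by ``det.get_cluster_info()``, and ``resolve_run_mode`` is an illustrative
helper, not part of the Determined API.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeClusterInfo:
    """Hypothetical stand-in for the object returned by det.get_cluster_info()."""

    latest_checkpoint: Optional[str] = None


def resolve_run_mode(info: Optional[FakeClusterInfo]) -> Tuple[bool, Optional[str]]:
    """Return (needs_manual_distributed_init, latest_checkpoint).

    No cluster info means a local run: distributed training has to be
    initialized by hand and there is no checkpoint to resume from.
    """
    if info is None:
        return True, None
    return False, info.latest_checkpoint


print(resolve_run_mode(None))
print(resolve_run_mode(FakeClusterInfo(latest_checkpoint="abc-123")))
```

Isolating the decision in one function makes it straightforward to unit test both paths before
submitting to a cluster.
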
Submit Your Trial for Training on Cluster
=========================================

To run your experiment on a cluster, you need to create an experiment configuration (YAML) file.
The experiment configuration file must contain a searcher configuration and an entrypoint.

.. code:: yaml

   name: dcgan_deepspeed_mnist
   searcher:
     name: single
     metric: validation_loss
   resources:
     slots_per_trial: 2
   entrypoint: python3 -m determined.launch.deepspeed python3 train.py

Submit the trial to the cluster:

.. code:: bash

   det e create det.yaml .

If your training code needs to read values from the experiment configuration, pass the
configuration to ``pytorch.deepspeed.init()`` through its ``exp_conf`` argument. You can then call
``context.get_experiment_config()`` on the ``DeepSpeedTrialContext``.
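
For local runs, one way to supply such a configuration is to build a dictionary that mirrors the
YAML file and check it before training starts. The sketch below assumes the configuration is a
plain dictionary. ``missing_required_keys`` is a hypothetical helper, and the required keys are
the two named in this section (searcher and entrypoint).

```python
from typing import Any, Dict, List

# The two keys this section says an experiment configuration must contain.
REQUIRED_KEYS = ("searcher", "entrypoint")


def missing_required_keys(exp_conf: Dict[str, Any]) -> List[str]:
    """Return the required top-level keys that are absent from exp_conf."""
    return [key for key in REQUIRED_KEYS if key not in exp_conf]


# Mirrors the YAML example in this section.
exp_conf = {
    "name": "dcgan_deepspeed_mnist",
    "searcher": {"name": "single", "metric": "validation_loss"},
    "resources": {"slots_per_trial": 2},
    "entrypoint": "python3 -m determined.launch.deepspeed python3 train.py",
}

print(missing_required_keys(exp_conf))
# The dictionary would then be passed as the exp_conf argument of
# pytorch.deepspeed.init(), as described above.
```

Checking the dictionary up front surfaces a missing key before any training code runs.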

Profiling
=========

When training on a cluster, you can enable the system metrics profiler by adding a parameter to
your ``fit()`` call:

.. code:: diff

     trainer.fit(
         ...,
   +     profiling_enabled=True
     )

*****************************
 Known DeepSpeed Constraints
*****************************
13 changes: 13 additions & 0 deletions docs/reference/training/api-deepspeed-reference.rst
@@ -48,3 +48,16 @@ documentation):
- :ref:`determined.pytorch.samplers <pytorch-samplers>`
- :ref:`determined.pytorch.MetricReducer <pytorch-metric-reducer>`
- :ref:`determined.pytorch.PyTorchCallback <pytorch-callbacks>`

******************************************
``determined.pytorch.deepspeed.Trainer``
******************************************

.. autoclass:: determined.pytorch.deepspeed.Trainer
   :members:

*****************************************
``determined.pytorch.deepspeed.init()``
*****************************************

.. autofunction:: determined.pytorch.deepspeed.init
8 changes: 7 additions & 1 deletion examples/deepspeed/dcgan/README.md
@@ -25,10 +25,16 @@ After installing docker and pulling an image, users can launch a container via

 Install necessary dependencies via `pip install determined mpi4py`

-Then, run the following command:
+Then, run the following command if running on a single node and GPU:
 ```
 python trainer.py
 ```
+For multiple nodes or GPUs, use the following:
+```
+deepspeed --num_nodes=<node_count> --num_gpus=<gpu_count> trainer.py --deepspeed --deepspeed_config ds_config.json
+```
+Here, `num_nodes` is the number of nodes in your local cluster and `num_gpus` is
+the number of GPUs per node.

 Any additional configs can be specified in `mnist.yaml` and `ds_config.json` accordingly.

18 changes: 9 additions & 9 deletions examples/deepspeed/dcgan/model.py
@@ -46,7 +46,7 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext,
 self.discriminator = self.context.wrap_model_engine(discriminator)
 self.fixed_noise = self.context.to_device(
 torch.randn(
-self.context.train_micro_batch_size_per_gpu, self.hparams["noise_length"], 1, 1
+self.context.get_train_micro_batch_size_per_gpu(), self.hparams["noise_length"], 1, 1
 )
 )
 self.criterion = nn.BCELoss()
@@ -62,7 +62,7 @@ def _get_noise(self, dtype: torch.dtype) -> torch.Tensor:
 torch.Tensor,
 self.context.to_device(
 torch.randn(
-self.context.train_micro_batch_size_per_gpu,
+self.context.get_train_micro_batch_size_per_gpu(),
 self.hparams["noise_length"],
 1,
 1,
@@ -93,7 +93,7 @@ def train_batch(
 else:
 dtype = torch.float32
 real_label, fake_label = self._get_label_constants(
-self.context.train_micro_batch_size_per_gpu, dtype
+self.context.get_train_micro_batch_size_per_gpu(), dtype
 )
 ############################
 # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
@@ -106,7 +106,7 @@ def train_batch(
 D_x = 0.0
 D_G_z1 = 0.0
 fake_sample_count = (
-self.context.train_micro_batch_size_per_gpu * self.gradient_accumulation_steps
+self.context.get_train_micro_batch_size_per_gpu() * self.gradient_accumulation_steps
 )

 for i in range(self.gradient_accumulation_steps):
@@ -132,7 +132,7 @@ def train_batch(
 output = self.discriminator(fake.detach())
 errD_fake = self.criterion(output, fake_label)
 self.discriminator.backward(errD_fake)
-errD_fake_sum += errD_fake * self.context.train_micro_batch_size_per_gpu
+errD_fake_sum += errD_fake * self.context.get_train_micro_batch_size_per_gpu()
 D_G_z1 += output.sum().item()
 # update
 self.discriminator.step()
@@ -153,7 +153,7 @@ def train_batch(
 output = self.discriminator(fake)
 errG = self.criterion(output, real_label) # fake labels are real for generator cost
 self.generator.backward(errG)
-errG_sum += errG * self.context._train_micro_batch_size_per_gpu
+errG_sum += errG * self.context.get_train_micro_batch_size_per_gpu()
 D_G_z2_sum += output.sum().item()
 self.generator.step()

@@ -188,7 +188,7 @@ def build_training_data_loader(self) -> Any:
 dataset = data.get_dataset(self.data_config)
 return DataLoader(
 dataset,
-batch_size=self.context.train_micro_batch_size_per_gpu,
+batch_size=self.context.get_train_micro_batch_size_per_gpu(),
 shuffle=True,
 num_workers=int(self.hparams["data_workers"]),
 )
@@ -200,9 +200,9 @@ def build_validation_data_loader(self) -> Any:
 dataset,
 list(
 range(
-self.context.train_micro_batch_size_per_gpu
+self.context.get_train_micro_batch_size_per_gpu()
 * self.context.distributed.get_size()
 )
 ),
 )
-return DataLoader(dataset, batch_size=self.context.train_micro_batch_size_per_gpu)
+return DataLoader(dataset, batch_size=self.context.get_train_micro_batch_size_per_gpu())