diff --git a/examples/deepspeed/dcgan/README.md b/examples/deepspeed/dcgan/README.md
new file mode 100644
index 00000000000..f0b9811b9c9
--- /dev/null
+++ b/examples/deepspeed/dcgan/README.md
@@ -0,0 +1,49 @@
+# DeepSpeed CIFAR Example
+This example is adapted from the
+[DCGAN example in the DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/training/gan)
+repository. It is intended to demonstrate a simple use case of DeepSpeed with Determined.
+
+## Files
+* **model.py**: The DCGANTrial definition.
+* **gan_model.py**: Network definitions for generator and discriminator.
+* **data.py**: Dataset loading/downloading code.
+
+### Configuration Files
+* **ds_config.json**: The DeepSpeed config file.
+* **mnist.yaml**: Determined config to train the model on MNIST on a cluster.
+
+## Data
+This repo supports the same datasets as the original example: `["imagenet", "lfw", "lsun", "cifar10", "mnist", "fake", "celeba"]`. The `cifar10` and `mnist` datasets will be downloaded as needed, whereas the rest must be mounted on the agent. For `lsun`, the `data_config.classes` setting must be set. The `folder` dataset can be used to load an arbitrary torchvision `ImageFolder` that is mounted on the agent.
+
+## To Run Locally
+
+It is recommended to run this from within one of our agent docker images, found at
+https://hub.docker.com/r/determinedai/pytorch-ngc/tags
+
+After installing docker and pulling an image, users can launch a container via
+`docker run --gpus=all -v /path/to/repo:/src/proj -it <image>`
+
+Install the necessary dependencies via `pip install determined mpi4py`.
+
+Then, run the following command:
+```
+python trainer.py
+```
+
+Any additional configuration can be specified in `mnist.yaml` and `ds_config.json` as needed.
+
+## To Run on Cluster
+If you have not yet installed Determined, installation instructions can be found
+under `docs/install-admin.html` or at https://docs.determined.ai/latest/index.html
+
+Run the following command:
+```
+det experiment create mnist.yaml .
+```
+The other configurations can be run by specifying the appropriate configuration file in place
+of `mnist.yaml`.
+
+## Results
+Training `mnist` should yield reasonable-looking fake digit images on the Images tab in TensorBoard after ~5k steps.
+
+Training `cifar10` does not converge as convincingly, but the samples should look image-like after ~10k steps.
diff --git a/examples/deepspeed/dcgan/data.py b/examples/deepspeed/dcgan/data.py
new file mode 100644
index 00000000000..c950df584d1
--- /dev/null
+++ b/examples/deepspeed/dcgan/data.py
@@ -0,0 +1,104 @@
+import contextlib
+import os
+from typing import cast
+
+import filelock
+import torch
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+
+CHANNELS_BY_DATASET = {
+    "imagenet": 3,
+    "folder": 3,
+    "lfw": 3,
+    "lsun": 3,
+    "cifar10": 3,
+    "mnist": 1,
+    "fake": 3,
+    "celeba": 3,
+}
+
+
+def get_dataset(data_config: dict) -> torch.utils.data.Dataset:
+    if data_config.get("dataroot", None) is None:
+        if str(data_config.get("dataset", "")).lower() != "fake":
+            raise ValueError('`dataroot` parameter is required for dataset "%s"'
+                             % data_config.get("dataset", ""))
+        else:
+            context = contextlib.nullcontext()
+    else:
+        # Ensure that only one local process attempts to download/validate datasets at once.
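+        # Other local ranks block on the lock until the first process finishes downloading, then reuse the data.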
+ context = filelock.FileLock(os.path.join(data_config["dataroot"], ".lock")) + with context: + if data_config["dataset"] in ["imagenet", "folder", "lfw"]: + # folder dataset + dataset = dset.ImageFolder( + root=data_config["dataroot"], + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.CenterCrop(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + elif data_config["dataset"] == "lsun": + classes = [c + "_train" for c in data_config["classes"].split(",")] + dataset = dset.LSUN( + root=data_config["dataroot"], + classes=classes, + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.CenterCrop(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + elif data_config["dataset"] == "cifar10": + dataset = dset.CIFAR10( + root=data_config["dataroot"], + download=True, + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + elif data_config["dataset"] == "mnist": + dataset = dset.MNIST( + root=data_config["dataroot"], + download=True, + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)), + ] + ), + ) + elif data_config["dataset"] == "fake": + dataset = dset.FakeData( + image_size=(3, data_config["image_size"], data_config["image_size"]), + transform=transforms.ToTensor(), + ) + elif data_config["dataset"] == "celeba": + dataset = dset.ImageFolder( + root=data_config["dataroot"], + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.CenterCrop(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + else: + unknown_dataset_name = data_config["dataset"] + raise Exception(f"Unknown dataset {unknown_dataset_name}") + return cast(torch.utils.data.Dataset, dataset) diff --git a/examples/deepspeed/dcgan/ds_config.json b/examples/deepspeed/dcgan/ds_config.json new file mode 100644 index 00000000000..708952b50b2 --- /dev/null +++ b/examples/deepspeed/dcgan/ds_config.json @@ -0,0 +1,15 @@ +{ + "train_batch_size": 64, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "betas": [ + 0.5, + 0.999 + ], + "eps": 1e-8 + } + }, + "steps_per_print": 10 +} diff --git a/examples/deepspeed/dcgan/gan_model.py b/examples/deepspeed/dcgan/gan_model.py new file mode 100644 index 00000000000..97ed726f45b --- /dev/null +++ b/examples/deepspeed/dcgan/gan_model.py @@ -0,0 +1,73 @@ +from typing import cast + +import torch +import torch.nn as nn + + +def weights_init(m: nn.Module) -> None: + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(cast(torch.Tensor, m.weight.data), 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(cast(torch.Tensor, m.weight.data), 1.0, 0.02) + nn.init.constant_(cast(torch.Tensor, m.bias.data), 0) + + +class Generator(nn.Module): + def __init__(self, ngf: int, nc: int, nz: int) -> None: + super(Generator, self).__init__() # type: ignore + self.main = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), + nn.BatchNorm2d(ngf * 8), # type: ignore + nn.ReLU(True), + # state size. 
(ngf*8) x 4 x 4 + nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 4), # type: ignore + nn.ReLU(True), + # state size. (ngf*4) x 8 x 8 + nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 2), # type: ignore + nn.ReLU(True), + # state size. (ngf*2) x 16 x 16 + nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf), # type: ignore + nn.ReLU(True), + # state size. (ngf) x 32 x 32 + nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), + nn.Tanh() # type: ignore + # state size. (nc) x 64 x 64 + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.main(input) + return cast(torch.Tensor, output) + + +class Discriminator(nn.Module): + def __init__(self, ndf: int, nc: int) -> None: + super(Discriminator, self).__init__() # type: ignore + self.main = nn.Sequential( + # input is (nc) x 64 x 64 + nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf) x 32 x 32 + nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 2), # type: ignore + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*2) x 16 x 16 + nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 4), # type: ignore + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*4) x 8 x 8 + nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 8), # type: ignore + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*8) x 4 x 4 + nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), + nn.Sigmoid(), # type: ignore + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.main(input) + return cast(torch.Tensor, output.view(-1, 1).squeeze(1)) diff --git a/examples/deepspeed/dcgan/mnist.yaml b/examples/deepspeed/dcgan/mnist.yaml new file mode 100644 index 00000000000..fb996c55532 --- /dev/null +++ b/examples/deepspeed/dcgan/mnist.yaml @@ -0,0 +1,33 @@ +name: dcgan_deepspeed_mnist +data: + dataroot: /data + dataset: mnist + image_size: 64 +hyperparameters: + deepspeed_config: ds_config.json + noise_length: 100 + generator_width_base: 64 + discriminator_width_base: 64 + data_workers: 16 +environment: + environment_variables: + - NCCL_DEBUG=INFO + - NCCL_SOCKET_IFNAME=ens,eth,ib + image: determinedai/pytorch-ngc-dev:0736b6d +bind_mounts: + - host_path: /tmp + container_path: /data +resources: + slots_per_trial: 2 +searcher: + name: single + metric: no_validation_metric +min_validation_period: + batches: 0 +entrypoint: + - python3 + - -m + - determined.launch.deepspeed + - python3 + - trainer.py +max_restarts: 0 diff --git a/examples/deepspeed/dcgan/model.py b/examples/deepspeed/dcgan/model.py new file mode 100644 index 00000000000..8ceab93dc6a --- /dev/null +++ b/examples/deepspeed/dcgan/model.py @@ -0,0 +1,208 @@ +import logging +from typing import Any, Dict, Iterator, Optional, Tuple, Union, cast + +import data +import deepspeed +import torch +import torch.nn as nn +import torch.utils.data +import torchvision +from gan_model import Discriminator, Generator, weights_init + +from determined.pytorch import DataLoader, TorchData +from determined.pytorch import deepspeed as det_ds + +REAL_LABEL = 1 +FAKE_LABEL = 0 + + +class DCGANTrial(det_ds.DeepSpeedTrial): + def __init__(self, context: det_ds.DeepSpeedTrialContext, + hparams: dict, data_config: dict) -> None: + self.context = context + self.hparams = hparams + self.data_config = data_config + self.logger = self.context.get_tensorboard_writer() + num_channels = 
data.CHANNELS_BY_DATASET[self.data_config["dataset"]] + gen_net = Generator( + self.hparams["generator_width_base"], num_channels, self.hparams["noise_length"] + ) + gen_net.apply(weights_init) + disc_net = Discriminator(self.hparams["discriminator_width_base"], num_channels) + disc_net.apply(weights_init) + gen_parameters = filter(lambda p: p.requires_grad, gen_net.parameters()) + disc_parameters = filter(lambda p: p.requires_grad, disc_net.parameters()) + ds_config = det_ds.overwrite_deepspeed_config( + self.hparams["deepspeed_config"], self.hparams.get("overwrite_deepspeed_args", {}) + ) + generator, _, _, _ = deepspeed.initialize( + model=gen_net, model_parameters=gen_parameters, config=ds_config + ) + discriminator, _, _, _ = deepspeed.initialize( + model=disc_net, model_parameters=disc_parameters, config=ds_config + ) + + self.generator = self.context.wrap_model_engine(generator) + self.discriminator = self.context.wrap_model_engine(discriminator) + self.fixed_noise = self.context.to_device( + torch.randn( + self.context.train_micro_batch_size_per_gpu, self.hparams["noise_length"], 1, 1 + ) + ) + self.criterion = nn.BCELoss() + self.fp16 = generator.fp16_enabled() + self.gradient_accumulation_steps = generator.gradient_accumulation_steps() + # Manually perform gradient accumulation. + if self.gradient_accumulation_steps > 1: + logging.info("Disabling automatic gradient accumulation.") + self.context.disable_auto_grad_accumulation() + + def _get_noise(self, dtype: torch.dtype) -> torch.Tensor: + return cast( + torch.Tensor, + self.context.to_device( + torch.randn( + self.context.train_micro_batch_size_per_gpu, + self.hparams["noise_length"], + 1, + 1, + dtype=dtype, + ) + ), + ) + + def _get_label_constants( + self, batch_size: int, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + real_label = cast( + torch.Tensor, + self.context.to_device(torch.full((batch_size,), REAL_LABEL, dtype=dtype)), + ) + fake_label = cast( + torch.Tensor, + self.context.to_device(torch.full((batch_size,), FAKE_LABEL, dtype=dtype)), + ) + return real_label, fake_label + + def train_batch( + self, iter_dataloader: Optional[Iterator[TorchData]], epoch_idx: int, batch_idx: int + ) -> Union[torch.Tensor, Dict[str, Any]]: + assert iter_dataloader is not None + if self.fp16: + dtype = torch.float16 + else: + dtype = torch.float32 + real_label, fake_label = self._get_label_constants( + self.context.train_micro_batch_size_per_gpu, dtype + ) + ############################ + # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) + ########################### + self.discriminator.zero_grad() + + real_sample_count = 0 + errD_real_sum = 0.0 + errD_fake_sum = 0.0 + D_x = 0.0 + D_G_z1 = 0.0 + fake_sample_count = ( + self.context.train_micro_batch_size_per_gpu * self.gradient_accumulation_steps + ) + + for i in range(self.gradient_accumulation_steps): + # Note: at end of epoch, may receive a batch of size smaller than train_micro_batch_size_per_gpu. + # In that case, we end up training on more fake examples than real examples. + # train with real + real, _ = self.context.to_device(next(iter_dataloader)) + real = cast(torch.Tensor, real) + actual_batch_size = real.shape[0] + real_sample_count += actual_batch_size + if self.fp16: + real = real.half() + output = self.discriminator(real) + # For edge-case small batches, must cut real_label size to match. 
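+            # BCELoss requires the prediction and target tensors to have matching shapes, hence the slice below.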
+ errD_real = self.criterion(output, real_label[:actual_batch_size]) + self.discriminator.backward(errD_real) + # Undo averaging so we can re-average at end when reporting metrics. + errD_real_sum += errD_real * actual_batch_size + D_x += output.sum().item() + # train with fake + noise = self._get_noise(dtype) + fake = self.generator(noise) + output = self.discriminator(fake.detach()) + errD_fake = self.criterion(output, fake_label) + self.discriminator.backward(errD_fake) + errD_fake_sum += errD_fake * self.context.train_micro_batch_size_per_gpu + D_G_z1 += output.sum().item() + # update + self.discriminator.step() + D_x /= real_sample_count + D_G_z1 /= fake_sample_count + errD = (errD_real_sum / real_sample_count) + (errD_fake_sum / fake_sample_count) + ############################ + # (2) Update G network: maximize log(D(G(z))) + ########################### + self.generator.zero_grad() + D_G_z2_sum = 0.0 + errG_sum = 0.0 + for i in range(self.gradient_accumulation_steps): + if i > 0: + # Must repeat forward pass of generator for accumulation steps beyond the first. + noise = self._get_noise(dtype) + fake = self.generator(noise) + output = self.discriminator(fake) + errG = self.criterion(output, real_label) # fake labels are real for generator cost + self.generator.backward(errG) + errG_sum += errG * self.context._train_micro_batch_size_per_gpu + D_G_z2_sum += output.sum().item() + self.generator.step() + + if batch_idx % 100 == 0: + fake = self.generator(self.fixed_noise) + denormalized_real = (real + 1) / 2 + denormalized_fake = (fake + 1) / 2 + self.logger.add_image( + "real_images", torchvision.utils.make_grid(denormalized_real), batch_idx + ) + self.logger.add_image( + "fake_images", torchvision.utils.make_grid(denormalized_fake), batch_idx + ) + + return { + "errD": errD, + "errG": errG_sum / fake_sample_count, + "D_x": D_x, + "D_G_z1": D_G_z1, + "D_G_z2": D_G_z2_sum / fake_sample_count, + } + + def evaluate_batch( + self, dataloader_iter: Optional[Iterator[TorchData]], batch_idx: int + ) -> Dict[str, Any]: + # TODO: We could add an evaluation metric like FID here. + assert dataloader_iter is not None + next(dataloader_iter) + return {"no_validation_metric": 0.0} + + def build_training_data_loader(self) -> Any: + dataset = data.get_dataset(self.data_config) + return DataLoader( + dataset, + batch_size=self.context.train_micro_batch_size_per_gpu, + shuffle=True, + num_workers=int(self.hparams["data_workers"]), + ) + + def build_validation_data_loader(self) -> Any: + dataset = data.get_dataset(self.data_config) + # Since we're not doing validation, limit to single batch. + dataset = torch.utils.data.Subset( + dataset, + list( + range( + self.context.train_micro_batch_size_per_gpu + * self.context.distributed.get_size() + ) + ), + ) + return DataLoader(dataset, batch_size=self.context.train_micro_batch_size_per_gpu) diff --git a/examples/deepspeed/dcgan/trainer.py b/examples/deepspeed/dcgan/trainer.py new file mode 100644 index 00000000000..1d114430d6f --- /dev/null +++ b/examples/deepspeed/dcgan/trainer.py @@ -0,0 +1,38 @@ +import logging + +import model +import yaml + +import determined as det +from determined import pytorch +from determined.pytorch import deepspeed as det_ds + + +def main(config_file: str, local: bool = True): + info = det.get_cluster_info() + + if local: + # For convenience, use hparams from const.yaml for local mode. 
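+        # config_file defaults to mnist.yaml when this script is run as `python trainer.py` (see __main__ below).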
+ with open(config_file, "r") as f: + experiment_config = yaml.load(f, Loader=yaml.SafeLoader) + hparams = experiment_config["hyperparameters"] + data_config = experiment_config["data"] + latest_checkpoint = None + else: + hparams = info.trial.hparams + data_config = info.trial._config["data"] + latest_checkpoint = ( + info.latest_checkpoint + ) # (Optional) Configure checkpoint for pause/resume functionality. + + with det_ds.init() as train_context: + trial = model.DCGANTrial(train_context, hparams, data_config) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(200), latest_checkpoint=latest_checkpoint) + + +if __name__ == "__main__": + local = det.get_cluster_info() is None + # Configure logging + logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + main(config_file="mnist.yaml", local=local) diff --git a/examples/deepspeed/gpt_neox/det_utils.py b/examples/deepspeed/gpt_neox/det_utils.py index 608d30c7cfd..3a6eac44f1c 100644 --- a/examples/deepspeed/gpt_neox/det_utils.py +++ b/examples/deepspeed/gpt_neox/det_utils.py @@ -30,7 +30,7 @@ def get_neox_args(context): "checkpoint_factor": exp_config["min_validation_period"]["batches"], "eval_interval": exp_config["min_validation_period"]["batches"], "hostfile": os.environ.get("DET_DEEPSPEED_HOSTFILE_PATH"), - "seed": context.env.trial_seed, + "seed": context.get_trial_seed(), } ) for k, v in overwrite_values.items(): diff --git a/harness/determined/exec/harness.py b/harness/determined/exec/harness.py index ae950891e13..188855e7399 100644 --- a/harness/determined/exec/harness.py +++ b/harness/determined/exec/harness.py @@ -38,8 +38,13 @@ def main(train_entrypoint: str) -> int: # We can't import pytorch directly because if running TfKerasTrials with an image that contains # both torch and keras, keras will throw exceptions due to unexpected CUDNN library versions. - if hasattr(det, "pytorch") and issubclass(trial_class, det.pytorch.PyTorchTrial): - return _run_pytorch_trial(trial_class, info) + if hasattr(det, "pytorch"): + if hasattr(det.pytorch, "deepspeed") and issubclass( + trial_class, det.pytorch.deepspeed.DeepSpeedTrial + ): + return _run_deepspeed_trial(trial_class, info) + elif issubclass(trial_class, det.pytorch.PyTorchTrial): + return _run_pytorch_trial(trial_class, info) # TODO: Don't include EnvContext object in the future high-level APIs for PyTorch or Keras. # It was natural to create this big-blob-of-config object, but it was a mistake to pass it into @@ -194,6 +199,58 @@ def _run_pytorch_trial( return 0 +def _run_deepspeed_trial( + trial_class: "Type[det.pytorch.deepspeed.DeepSpeedTrial]", + info: det.ClusterInfo, +) -> int: + from determined import pytorch + from determined.pytorch import deepspeed as det_ds + + det.common.set_logger(info.trial._debug) + + logger.debug("Starting harness.") + + with det_ds.init( + hparams=info.trial.hparams, + exp_conf=info.trial._config, + ) as train_context: + trial_inst = trial_class(train_context) + + if train_context.distributed.size > 1 and not train_context.distributed.rank == 0: + log_level = logging.DEBUG if info.trial._debug else logging.WARNING + logging.getLogger().setLevel(log_level) + + logger.info( + f"Creating {det_ds.DeepSpeedTrialController.__name__} with {trial_class.__name__}." 
+ ) + + trainer = det_ds.Trainer(trial_inst, train_context) + + if "global_batch_size" in info.trial.hparams: + global_batch_size = int(info.trial.hparams["global_batch_size"]) # type: Optional[int] + else: + global_batch_size = None + + trainer.fit( + checkpoint_period=pytorch.TrainUnit._from_values( + **info.trial._config["min_checkpoint_period"], + global_batch_size=global_batch_size, + ), + validation_period=pytorch.TrainUnit._from_values( + **info.trial._config["min_validation_period"], + global_batch_size=global_batch_size, + ), + reporting_period=pytorch.Batch(info.trial._config["scheduling_unit"]), + checkpoint_policy=info.trial._config["checkpoint_policy"], + latest_checkpoint=info.latest_checkpoint, + step_zero_validation=info.trial._config["perform_initial_validation"], + test_mode=False, + profiling_enabled=bool(info.trial._config["profiling"]["enabled"]), + ) + + return 0 + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("train_entrypoint") diff --git a/harness/determined/pytorch/__init__.py b/harness/determined/pytorch/__init__.py index dbf60ce0316..4c055768abf 100644 --- a/harness/determined/pytorch/__init__.py +++ b/harness/determined/pytorch/__init__.py @@ -24,17 +24,20 @@ _convert_metrics_to_numpy, _log_tb_metrics, ) +from determined.pytorch._trainer_utils import ( + Batch, + Epoch, + _ShouldExit, + _TrainBoundary, + _TrainBoundaryType, + TrainUnit, + _TrialState, +) from determined.pytorch._experimental import PyTorchExperimentalContext from determined.pytorch._pytorch_context import PyTorchTrialContext from determined.pytorch._pytorch_trial import ( PyTorchTrial, _PyTorchTrialController, - TrainUnit, - _TrainBoundary, - _TrainBoundaryType, - _TrialState, - Batch, - Epoch, ) from determined.pytorch._load import CheckpointLoadContext, load_trial_from_checkpoint_path from determined.pytorch._trainer import init, Trainer diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index b82eb45ecdc..b431ed4d20c 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -1,6 +1,5 @@ import abc import contextlib -import enum import inspect import json import logging @@ -10,7 +9,6 @@ import sys import time import warnings -from collections import abc as col_abc from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union import numpy as np @@ -40,152 +38,14 @@ def dataloader_next(dataloader_iter: Iterator) -> Iterator: yield batch -class TrainUnit: - """ - TrainUnit is the base class for the supported training units (Batch, Epoch) containing - the value of unit, where the value can be an int or an implementable collections.abc.Container. - - TrainUnits are used to define periodic training behavior such as checkpointing and validating. - - int values are treated as periods, e.g. Batch(100) will checkpoint/validate every 100 batches. - collections.abc.Container values are treated as schedules, e.g. Batch(1,5,10) will - checkpoint/validate on batches 1, 5, and 10. 
- """ - - def __init__(self, value: Union[int, col_abc.Container]): - self.value = value - - @staticmethod - def _from_searcher_unit( - length: int, unit: Optional[core.Unit], global_batch_size: Optional[int] = None - ) -> "TrainUnit": - if unit == core.Unit.EPOCHS: - return Epoch(length) - elif unit == core.Unit.RECORDS: - if global_batch_size is None: - raise ValueError("global_batch_size required for searcher unit Records.") - return Batch._from_records(length, global_batch_size) - elif unit == core.Unit.BATCHES: - return Batch(length) - else: - raise ValueError(f"unrecognized searcher unit {unit}") - - @staticmethod - def _from_values( - batches: Optional[int] = None, - records: Optional[int] = None, - epochs: Optional[int] = None, - global_batch_size: Optional[int] = None, - ) -> "TrainUnit": - if sum((batches is not None, records is not None, epochs is not None)) != 1: - raise ValueError(f"invalid config: batches={batches} records={records} epochs={epochs}") - if batches is not None: - if batches < 1: - batches = sys.maxsize - return Batch(batches) - if records is not None: - assert global_batch_size, "global_batch_size is required for RECORD units." - if records < 1: - records = sys.maxsize - return Batch._from_records(records, global_batch_size) - if epochs is not None: - if epochs < 1: - epochs = sys.maxsize - return Epoch(epochs) - - # Make mypy happy - raise ValueError("invalid values") - - def should_stop(self, step_num: int) -> bool: - if isinstance(self.value, int): - return self._divides(step_num) - assert isinstance(self.value, col_abc.Container) - return step_num in self.value - - def _divides(self, steps: int) -> bool: - assert isinstance(steps, int) and isinstance( - self.value, int - ), "_divides can only be called on int types." - # Treat <= 0 values as always step - if self.value < 1: - return True - if steps == 0: - return False - return steps % self.value == 0 - - -class Epoch(TrainUnit): - """ - Epoch step type (e.g. Epoch(1) defines 1 epoch) - """ - - def __str__(self) -> str: - return f"Epoch({self.value})" - - -class Batch(TrainUnit): - """ - Batch step type (e.g. Batch(1) defines 1 batch) - """ - - @staticmethod - def _from_records(records: int, global_batch_size: int) -> "Batch": - return Batch(max(records // global_batch_size, 1)) - - def __str__(self) -> str: - return f"Batch({self.value})" - - -class _TrainBoundaryType(enum.Enum): - CHECKPOINT = "CHECKPOINT" - REPORT = "REPORT" - VALIDATE = "VALIDATE" - TRAIN = "TRAIN" - - -class _TrainBoundary: - def __init__(self, step_type: _TrainBoundaryType, unit: TrainUnit): - self.step_type = step_type - self.unit = unit - self.limit_reached = False - - -class ShouldExit(Exception): - """ - ShouldExit breaks out of the top-level train loop from inside function calls. - """ - - def __init__(self, skip_exit_checkpoint: bool = False): - self.skip_exit_checkpoint = skip_exit_checkpoint - - -class _TrialState: - def __init__( - self, - trial_id: int = 0, - last_ckpt: int = 0, - step_id: int = 0, - last_val: int = 0, - batches_trained: int = 0, - epochs_trained: int = 0, - ) -> None: - # Store TrialID to distinguish between e.g. pause/restart and continue training. 
- self.trial_id = trial_id - self.last_ckpt = last_ckpt - self.step_id = step_id - self.last_val = last_val - self.batches_trained = batches_trained - self.epochs_trained = epochs_trained - - class _PyTorchTrialController: def __init__( self, trial_inst: det.LegacyTrial, context: pytorch.PyTorchTrialContext, - checkpoint_period: TrainUnit, - validation_period: TrainUnit, - reporting_period: TrainUnit, + checkpoint_period: pytorch.TrainUnit, + validation_period: pytorch.TrainUnit, + reporting_period: pytorch.TrainUnit, smaller_is_better: bool, steps_completed: int, latest_checkpoint: Optional[str], @@ -194,7 +54,7 @@ def __init__( searcher_metric_name: Optional[str], checkpoint_policy: str, step_zero_validation: bool, - max_length: TrainUnit, + max_length: pytorch.TrainUnit, global_batch_size: Optional[int], profiling_enabled: Optional[bool], ) -> None: @@ -221,7 +81,7 @@ def __init__( self.trial_id = 0 if local_training else self.core_context.train._trial_id # Don't initialize the state here because it will be invalid until we load a checkpoint. - self.state = None # type: Optional[_TrialState] + self.state = None # type: Optional[pytorch._TrialState] self.start_from_batch = steps_completed self.val_from_previous_run = self.core_context.train._get_last_validation() self.step_zero_validation = step_zero_validation @@ -398,7 +258,7 @@ def _checkpoint(self, already_exiting: bool) -> None: except det.InvalidHP: if not already_exiting: self.core_context.train.report_early_exit(core.EarlyExitReason.INVALID_HP) - raise ShouldExit(skip_exit_checkpoint=True) + raise pytorch._ShouldExit(skip_exit_checkpoint=True) raise def _check_evaluate_implementation(self) -> None: @@ -498,17 +358,17 @@ def _step_batch(self) -> None: def _stop_requested(self) -> None: if self.core_context.preempt.should_preempt(): - raise ShouldExit() + raise pytorch._ShouldExit() if self.context.get_stop_requested(): - raise ShouldExit() + raise pytorch._ShouldExit() def _report_training_progress(self) -> None: assert self.state assert isinstance(self.max_length.value, int) - if isinstance(self.max_length, Batch): + if isinstance(self.max_length, pytorch.Batch): progress = self.state.batches_trained / self.max_length.value - elif isinstance(self.max_length, Epoch): + elif isinstance(self.max_length, pytorch.Epoch): progress = self.state.epochs_trained / self.max_length.value else: raise ValueError(f"unexpected train unit type {type(self.max_length)}") @@ -525,12 +385,12 @@ def _validation_is_current(self) -> bool: # State persists validation step in batches return self.state.last_val == self.state.batches_trained - def _steps_until_complete(self, train_unit: TrainUnit) -> int: + def _steps_until_complete(self, train_unit: pytorch.TrainUnit) -> int: assert isinstance(train_unit.value, int), "invalid length type" assert self.state - if isinstance(train_unit, Batch): + if isinstance(train_unit, pytorch.Batch): return train_unit.value - self.state.batches_trained - elif isinstance(train_unit, Epoch): + elif isinstance(train_unit, pytorch.Epoch): return train_unit.value - self.state.epochs_trained else: raise ValueError(f"Unrecognized train unit {train_unit}") @@ -587,7 +447,7 @@ def cleanup_iterator() -> None: self._load(load_path) else: # If we are not loading, initialize a fresh state. 
- self.state = _TrialState(trial_id=self.trial_id) + self.state = pytorch._TrialState(trial_id=self.trial_id) if self.context.distributed.size > 1 and self.use_horovod: hvd = horovod.hvd @@ -616,24 +476,26 @@ def _run(self) -> None: self._validate() self._train( - length=Batch(1) if self.test_mode else self.max_length, + length=pytorch.Batch(1) if self.test_mode else self.max_length, train_boundaries=[ - _TrainBoundary( - step_type=_TrainBoundaryType.TRAIN, + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.TRAIN, unit=self.max_length, ), - _TrainBoundary( - step_type=_TrainBoundaryType.VALIDATE, unit=self.validation_period + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.VALIDATE, unit=self.validation_period ), - _TrainBoundary( - step_type=_TrainBoundaryType.CHECKPOINT, + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.CHECKPOINT, unit=self.checkpoint_period, ), # Scheduling unit is always configured in batches - _TrainBoundary(step_type=_TrainBoundaryType.REPORT, unit=self.reporting_period), + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.REPORT, unit=self.reporting_period + ), ], ) - except ShouldExit as e: + except pytorch._ShouldExit as e: # Checkpoint unsaved work and exit. if not e.skip_exit_checkpoint and not self._checkpoint_is_current(): self._checkpoint(already_exiting=True) @@ -645,8 +507,8 @@ def _run(self) -> None: return def _train_with_boundaries( - self, training_enumerator: Iterator, train_boundaries: List[_TrainBoundary] - ) -> Tuple[List[_TrainBoundary], List]: + self, training_enumerator: Iterator, train_boundaries: List[pytorch._TrainBoundary] + ) -> Tuple[List[pytorch._TrainBoundary], List]: training_metrics = [] # Start of train step: tell core API and set model mode @@ -677,19 +539,19 @@ def _train_with_boundaries( # Batch complete: check if any training periods have been reached and exit if any for step in train_boundaries: - if isinstance(step.unit, Batch): + if isinstance(step.unit, pytorch.Batch): if step.unit.should_stop(batch_idx + 1): step.limit_reached = True # True epoch based training not supported, detect last batch of epoch to calculate # fully-trained epochs - if isinstance(step.unit, Epoch): + if isinstance(step.unit, pytorch.Epoch): if step.unit.should_stop(epoch_idx + 1): if batch_in_epoch_idx == epoch_len - 1: step.limit_reached = True # Break early after one batch for test mode - if step.step_type == _TrainBoundaryType.TRAIN and self.test_mode: + if step.step_type == pytorch._TrainBoundaryType.TRAIN and self.test_mode: step.limit_reached = True # Exit if any train step limits have been reached @@ -699,7 +561,9 @@ def _train_with_boundaries( # True epoch end return train_boundaries, training_metrics - def _train(self, length: TrainUnit, train_boundaries: List[_TrainBoundary]) -> None: + def _train( + self, length: pytorch.TrainUnit, train_boundaries: List[pytorch._TrainBoundary] + ) -> None: while self._steps_until_complete(length) > 0: train_boundaries, training_metrics = self._train_with_boundaries( self.training_enumerator, train_boundaries @@ -720,18 +584,18 @@ def _train(self, length: TrainUnit, train_boundaries: List[_TrainBoundary]) -> N continue # Train step limits reached, proceed accordingly. 
- if train_boundary.step_type == _TrainBoundaryType.TRAIN: + if train_boundary.step_type == pytorch._TrainBoundaryType.TRAIN: if self.is_chief and not step_reported: self._report_training_progress() step_reported = True - elif train_boundary.step_type == _TrainBoundaryType.REPORT: + elif train_boundary.step_type == pytorch._TrainBoundaryType.REPORT: if self.is_chief and not step_reported: self._report_training_progress() step_reported = True - elif train_boundary.step_type == _TrainBoundaryType.VALIDATE: + elif train_boundary.step_type == pytorch._TrainBoundaryType.VALIDATE: if not self._validation_is_current(): self._validate() - elif train_boundary.step_type == _TrainBoundaryType.CHECKPOINT: + elif train_boundary.step_type == pytorch._TrainBoundaryType.CHECKPOINT: if not self._checkpoint_is_current(): self._checkpoint(already_exiting=False) @@ -1187,10 +1051,10 @@ def _load_state(self, state: Any) -> None: # If the trial_id doesn't match our current trial id, we're continuing training a previous # trial and should start from a fresh state. if state.get("trial_id") != self.trial_id: - self.state = _TrialState(trial_id=self.trial_id) + self.state = pytorch._TrialState(trial_id=self.trial_id) return - self.state = _TrialState(**state) + self.state = pytorch._TrialState(**state) assert self.state # Detect the case where the final validation we made was against this exact checkpoint. In @@ -1203,10 +1067,10 @@ def _load_state(self, state: Any) -> None: def _load_wlsq_state(self, state: Any) -> None: if state.get("trial_id") != self.trial_id: - self.state = _TrialState(trial_id=self.trial_id) + self.state = pytorch._TrialState(trial_id=self.trial_id) return - self.state = _TrialState( + self.state = pytorch._TrialState( trial_id=state.get("trial_id"), last_ckpt=state.get("last_ckpt"), last_val=state.get("last_val"), diff --git a/harness/determined/pytorch/_trainer_utils.py b/harness/determined/pytorch/_trainer_utils.py new file mode 100644 index 00000000000..254fad6e150 --- /dev/null +++ b/harness/determined/pytorch/_trainer_utils.py @@ -0,0 +1,145 @@ +import enum +import sys +from collections import abc +from typing import Optional, Union + +from determined import core + + +class TrainUnit: + """ + TrainUnit is the base class for the supported training units (Batch, Epoch) containing + the value of unit, where the value can be an int or an implementable collections.abc.Container. + + TrainUnits are used to define periodic training behavior such as checkpointing and validating. + + int values are treated as periods, e.g. Batch(100) will checkpoint/validate every 100 batches. + collections.abc.Container values are treated as schedules, e.g. Batch(1,5,10) will + checkpoint/validate on batches 1, 5, and 10. 
+ """ + + def __init__(self, value: Union[int, abc.Container]): + self.value = value + + @staticmethod + def _from_searcher_unit( + length: int, unit: Optional[core.Unit], global_batch_size: Optional[int] = None + ) -> "TrainUnit": + if unit == core.Unit.EPOCHS: + return Epoch(length) + elif unit == core.Unit.RECORDS: + if global_batch_size is None: + raise ValueError("global_batch_size required for searcher unit Records.") + return Batch._from_records(length, global_batch_size) + elif unit == core.Unit.BATCHES: + return Batch(length) + else: + raise ValueError(f"unrecognized searcher unit {unit}") + + def _to_searcher_unit(self) -> core.Unit: + if isinstance(self, Batch): + return core.Unit.BATCHES + return core.Unit.EPOCHS + + @staticmethod + def _from_values( + batches: Optional[int] = None, + records: Optional[int] = None, + epochs: Optional[int] = None, + global_batch_size: Optional[int] = None, + ) -> "TrainUnit": + if sum((batches is not None, records is not None, epochs is not None)) != 1: + raise ValueError(f"invalid config: batches={batches} records={records} epochs={epochs}") + if batches is not None: + if batches < 1: + batches = sys.maxsize + return Batch(batches) + if records is not None: + assert global_batch_size, "global_batch_size is required for RECORD units." + if records < 1: + records = sys.maxsize + return Batch._from_records(records, global_batch_size) + if epochs is not None: + if epochs < 1: + epochs = sys.maxsize + return Epoch(epochs) + + # Make mypy happy + raise ValueError("invalid values") + + def should_stop(self, step_num: int) -> bool: + if isinstance(self.value, int): + return self._divides(step_num) + assert isinstance(self.value, abc.Container) + return step_num in self.value + + def _divides(self, steps: int) -> bool: + assert isinstance(steps, int) and isinstance( + self.value, int + ), "_divides can only be called on int types." + # Treat <= 0 values as always step + if self.value < 1: + return True + if steps == 0: + return False + return steps % self.value == 0 + + +class Epoch(TrainUnit): + """ + Epoch step type (e.g. Epoch(1) defines 1 epoch) + """ + + pass + + +class Batch(TrainUnit): + """ + Batch step type (e.g. Batch(1) defines 1 batch) + """ + + @staticmethod + def _from_records(records: int, global_batch_size: int) -> "Batch": + return Batch(max(records // global_batch_size, 1)) + + +class _ShouldExit(Exception): + """ + ShouldExit breaks out of the top-level train loop from inside function calls. + """ + + def __init__(self, skip_exit_checkpoint: bool = False): + self.skip_exit_checkpoint = skip_exit_checkpoint + + +class _TrialState: + def __init__( + self, + trial_id: int = 0, + last_ckpt: int = 0, + step_id: int = 0, + last_val: int = 0, + batches_trained: int = 0, + epochs_trained: int = 0, + ) -> None: + # Store TrialID to distinguish between e.g. pause/restart and continue training. 
+ self.trial_id = trial_id + self.last_ckpt = last_ckpt + self.step_id = step_id + self.last_val = last_val + self.batches_trained = batches_trained + self.epochs_trained = epochs_trained + + +class _TrainBoundaryType(enum.Enum): + CHECKPOINT = "CHECKPOINT" + REPORT = "REPORT" + VALIDATE = "VALIDATE" + TRAIN = "TRAIN" + + +class _TrainBoundary: + def __init__(self, step_type: _TrainBoundaryType, unit: TrainUnit): + self.step_type = step_type + self.unit = unit + self.limit_reached = False diff --git a/harness/determined/pytorch/deepspeed/__init__.py b/harness/determined/pytorch/deepspeed/__init__.py index 46b40dc66f7..62cb79dfaaf 100644 --- a/harness/determined/pytorch/deepspeed/__init__.py +++ b/harness/determined/pytorch/deepspeed/__init__.py @@ -8,3 +8,4 @@ overwrite_deepspeed_config, ) from determined.pytorch.deepspeed._deepspeed_trial import DeepSpeedTrial, DeepSpeedTrialController +from determined.pytorch.deepspeed._trainer import init, Trainer diff --git a/harness/determined/pytorch/deepspeed/_deepspeed_context.py b/harness/determined/pytorch/deepspeed/_deepspeed_context.py index dbb80c7f651..b71f44e31da 100644 --- a/harness/determined/pytorch/deepspeed/_deepspeed_context.py +++ b/harness/determined/pytorch/deepspeed/_deepspeed_context.py @@ -1,5 +1,6 @@ import json import logging +import pathlib import time from importlib import util as importutil from typing import Any, Dict, List, Optional, Set, Type, Union, cast @@ -42,7 +43,7 @@ def overwrite_deepspeed_config( return util.merge_dicts(cast(Dict[str, Any], base_ds_config), source_ds_dict) -class DeepSpeedTrialContext(det.TrialContext, pytorch._PyTorchReducerContext): +class DeepSpeedTrialContext(pytorch._PyTorchReducerContext): """Contains runtime information for any Determined workflow that uses the ``DeepSpeedTrial`` API. @@ -65,10 +66,38 @@ class DeepSpeedTrialContext(det.TrialContext, pytorch._PyTorchReducerContext): 5. Disable automatic gradient aggregation for non-pipeline-parallel training. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - det.TrialContext.__init__(self, *args, **kwargs) + def __init__( + self, + core_context: det.core.Context, + trial_seed: Optional[int], + hparams: Optional[Dict], + slots_per_trial: int, + num_gpus: int, + exp_conf: Optional[Dict[str, Any]], + steps_completed: int, + enable_tensorboard_logging: bool = True, + ) -> None: + self._core = core_context + self.distributed = self._core.distributed + pytorch._PyTorchReducerContext.__init__(self, self.distributed.allgather) + self._per_slot_batch_size, self._global_batch_size = ( + util.calculate_batch_sizes( + hparams=hparams, + slots_per_trial=slots_per_trial, + trialname="DeepSpeedTrial", + ) + if hparams and hparams.get("global_batch_size", None) + else (None, None) + ) + self._hparams = hparams + self._num_gpus = num_gpus + self._exp_conf = exp_conf + + self._trial_seed = trial_seed + self._steps_completed = steps_completed + self._init_device() # Track which types we have issued warnings for in to_device(). @@ -85,14 +114,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # The following attributes are initialized during the lifetime of # a DeepSpeedTrialContext. 
self.models = [] # type: List[deepspeed.DeepSpeedEngine] + self.profiler = None # type: Any self._epoch_len = None # type: Optional[int] self._loss_ids = {} # type: Dict[torch.Tensor, int] self._last_backward_batch_idx = None # type: Optional[int] self._current_batch_idx = None # type: Optional[int] - self.profiler = None # type: Any - self._mpu = det_ds.make_data_parallel_mpu( self.distributed ) # type: det_ds.ModelParallelUnit @@ -103,48 +131,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._data_repro_checks_disabled = False self._manual_grad_accumulation = False - self._check_experiment_config_optimizations() + self._stop_requested = False self._tbd_writer = None # type: Optional[Any] - self._enable_tensorboard_logging = True + self._enable_tensorboard_logging = enable_tensorboard_logging # Timestamp for batching TensorBoard uploads self._last_tb_reset_ts: Optional[float] = None - def _check_experiment_config_optimizations(self) -> None: - """ - Check if the user specified options in optimizations are incompatible with - DeepSpeedTrial. - """ - optimizations_config = self.env.experiment_config.get_optimizations_config() - self._average_training_metrics = optimizations_config.get("average_training_metrics", False) - - mixed_precision_val = optimizations_config.get("mixed_precision", "O0") - if mixed_precision_val != "O0": - raise det.errors.InvalidExperimentException( - "Mixed precision is specified through the deepspeed config instead of the " - "Determined experiment config.", - ) - aggregation_frequency = optimizations_config.get("aggregation_frequency", 1) - if aggregation_frequency > 1: - raise det.errors.InvalidExperimentException( - "Gradient aggregation is specified through the deepspeed config instead of the " - "Determined experiment config.", - ) - other_optimizations_default_values = { - "average_aggregated_gradients": True, - "gradient_compression": False, - "tensor_fusion_threshold": 64, - "tensor_fusion_cycle_time": 5, - "autotune_tensor_fusion": False, - } - for opt_field, default_value in other_optimizations_default_values.items(): - opt_value = optimizations_config.get(opt_field, default_value) - if opt_value != default_value: - logger.warning( - f"{opt_field}={opt_value} ignored since the setting does not apply " - "to DeepSpeedTrial." - ) - def set_mpu(self, mpu: det_ds.ModelParallelUnit) -> None: """Use a custom model parallel configuration. @@ -166,12 +159,6 @@ def set_mpu(self, mpu: det_ds.ModelParallelUnit) -> None: "Only one MPU can be passed to DeepSpeedTrialContext. " "Please make sure wrap_mpu is only called once in the trial definition." ) - if self.distributed.rank == 0: - if not self._mpu.should_report_metrics and not self._average_training_metrics: - raise det.errors.InvalidExperimentException( - "Please set optimizations.average_training_metrics in the experiment config " - "to true so that metrics will exist on the chief for report to the master." - ) self._called_set_mpu = True self._mpu = mpu @@ -245,16 +232,14 @@ def disable_dataset_reproducibility_checks(self) -> None: def use_pipeline_parallel(self) -> bool: return self._use_pipeline_parallel - @property - def train_micro_batch_size_per_gpu(self) -> int: + def get_train_micro_batch_size_per_gpu(self) -> int: if self._train_micro_batch_size_per_gpu is None: raise det.errors.InvalidExperimentException( "Please call wrap_model_engine before accessing train_micro_batch_size." 
) return self._train_micro_batch_size_per_gpu - @property - def num_micro_batches_per_slot(self) -> int: + def get_num_micro_batches_per_slot(self) -> int: if self._num_micro_batches_per_slot is None: raise det.errors.InvalidExperimentException( "Please call wrap_model_engine before accessing num_micro_batches_per_slot." @@ -262,8 +247,7 @@ def num_micro_batches_per_slot(self) -> int: return self._num_micro_batches_per_slot def _init_device(self) -> None: - self.n_gpus = len(self.env.container_gpus) - if not self.n_gpus: + if not self._num_gpus: raise det.errors.InvalidExperimentException("GPUs required for DeepSpeedTrial.") if self.distributed.size > 1: self.device = torch.device("cuda", self.distributed.get_local_rank()) @@ -359,6 +343,38 @@ def set_profiler(self, *args: List[str], **kwargs: Any) -> None: **kwargs, ) + def get_initial_batch(self) -> int: + return self._steps_completed + + def get_data_config(self) -> Dict[str, Any]: + """ + Return the data configuration. + """ + return self.get_experiment_config().get("data", {}) + + def get_experiment_id(self) -> int: + """ + Return the experiment ID of the current trial. + """ + return int(self._core.train._exp_id) + + def get_trial_id(self) -> int: + """ + Return the trial ID of the current trial. + """ + return int(self._core.train._trial_id) + + def get_trial_seed(self) -> int: + if self._trial_seed is None: + raise det.errors.InternalException("Trial seed not set.") + return self._trial_seed + + def get_tensorboard_path(self) -> pathlib.Path: + """ + Get the path where files for consumption by TensorBoard should be written + """ + return self._core.train.get_tensorboard_path() + def get_tensorboard_writer(self) -> Any: """ This function returns an instance of ``torch.utils.tensorboard.SummaryWriter`` @@ -442,3 +458,86 @@ def get_enable_tensorboard_logging(self) -> bool: Return whether automatic tensorboard logging is enabled """ return self._enable_tensorboard_logging + + def get_global_batch_size(self) -> int: + """ + Return the global batch size. + """ + if self._global_batch_size is None: + raise ValueError( + "global_batch_size is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + return self._global_batch_size + + def get_per_slot_batch_size(self) -> int: + """ + Return the per-slot batch size. When a model is trained with a single GPU, this is equal to + the global batch size. When multi-GPU training is used, this is equal to the global batch + size divided by the number of GPUs used to train the model. + """ + if self._per_slot_batch_size is None: + raise ValueError( + "per_slot_batch_size is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + + return self._per_slot_batch_size + + def get_experiment_config(self) -> Dict[str, Any]: + if self._exp_conf is None: + raise ValueError( + "exp_conf is undefined in this Trial. Please check the init() call to Trainer API." + ) + return self._exp_conf + + def get_hparam(self, name: str) -> Any: + """ + Return the current value of the hyperparameter with the given name. + """ + if self._hparams is None: + raise ValueError( + "hparams is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + if name not in self.get_hparams(): + raise ValueError( + "Could not find name '{}' in experiment " + "hyperparameters. 
Please check your experiment " + "configuration 'hyperparameters' section.".format(name) + ) + if name == "global_batch_size": + logger.warning( + "Please use `context.get_per_slot_batch_size()` and " + "`context.get_global_batch_size()` instead of accessing " + "`global_batch_size` directly." + ) + return self.get_hparams()[name] + + def get_hparams(self) -> Dict[str, Any]: + if self._hparams is None: + raise ValueError( + "hparams is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + return self._hparams + + def get_stop_requested(self) -> bool: + """ + Return whether a trial stoppage has been requested. + """ + return self._stop_requested + + def set_stop_requested(self, stop_requested: bool) -> None: + """ + Set a flag to request a trial stoppage. When this flag is set to True, + we finish the step, checkpoint, then exit. + """ + if not isinstance(stop_requested, bool): + raise AssertionError("stop_requested must be a boolean") + + logger.info( + "A trial stoppage has requested. The trial will be stopped " + "at the end of the current step." + ) + self._stop_requested = stop_requested diff --git a/harness/determined/pytorch/deepspeed/_deepspeed_trial.py b/harness/determined/pytorch/deepspeed/_deepspeed_trial.py index db2a96d1ccd..8c8d3f5d599 100644 --- a/harness/determined/pytorch/deepspeed/_deepspeed_trial.py +++ b/harness/determined/pytorch/deepspeed/_deepspeed_trial.py @@ -1,5 +1,7 @@ import abc import contextlib +import inspect +import json import logging import os import pathlib @@ -7,7 +9,7 @@ import random import time import warnings -from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, cast +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import deepspeed import numpy as np @@ -15,7 +17,7 @@ from deepspeed.runtime import dataloader as ds_loader import determined as det -from determined import layers, pytorch, util, workload +from determined import core, pytorch, tensorboard, util from determined.pytorch import deepspeed as det_ds logger = logging.getLogger("determined.pytorch") @@ -30,18 +32,48 @@ def get_length(self: ds_loader.RepeatingLoader) -> int: return len(self.loader) -ds_loader.RepeatingLoader.__len__ = get_length +def dataloader_next(dataloader_iter: Optional[Iterator]) -> Iterator: + if dataloader_iter is None: + return None + while True: + try: + batch = next(dataloader_iter) + except StopIteration: + return + yield batch + +ds_loader.RepeatingLoader.__len__ = get_length -class DeepSpeedTrialController(det.TrialController): - def __init__(self, trial_inst: det.LegacyTrial, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) +class DeepSpeedTrialController: + def __init__( + self, + trial_inst: det.LegacyTrial, + context: det_ds.DeepSpeedTrialContext, + checkpoint_period: pytorch.TrainUnit, + validation_period: pytorch.TrainUnit, + reporting_period: pytorch.TrainUnit, + smaller_is_better: bool, + steps_completed: int, + latest_checkpoint: Optional[str], + local_training: bool, + test_mode: bool, + searcher_metric_name: Optional[str], + checkpoint_policy: str, + step_zero_validation: bool, + max_length: pytorch.TrainUnit, + global_batch_size: Optional[int], + profiling_enabled: Optional[bool], + ) -> None: assert isinstance( trial_inst, DeepSpeedTrial ), "DeepSpeedTrialController needs a DeepSpeedTrial" self.trial = trial_inst - self.context = cast(det_ds.DeepSpeedTrialContext, self.context) + self.context = context 
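+        # The Core API context is used below for checkpointing, metric reporting, and preemption checks.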
+ self.core_context = self.context._core + + self.is_chief = self.context.distributed.rank == 0 self.callbacks = self.trial.build_callbacks() for callback in self.callbacks.values(): @@ -59,18 +91,35 @@ def __init__(self, trial_inst: det.LegacyTrial, *args: Any, **kwargs: Any) -> No "This might be caused by not wrapping your model with wrap_model_engine()" ) - self.wlsq = None # type: Optional[layers.WorkloadSequencer] - if self.workloads is None: - self.workloads, self.wlsq = layers.make_compatibility_workloads( - self.context._core, self.env, self.context.models[0].train_batch_size() - ) - - self.steps_completed = self.env.steps_completed + # Don't initialize the state here because it will be invalid until we load a checkpoint. + self.state = None # type: Optional[pytorch._TrialState] + self.start_from_batch = steps_completed + self.val_from_previous_run = self.core_context.train._get_last_validation() + self.step_zero_validation = step_zero_validation + + # Training configs + self.latest_checkpoint = latest_checkpoint + self.test_mode = test_mode + self.searcher_metric_name = searcher_metric_name + self.checkpoint_policy = checkpoint_policy + self.smaller_is_better = smaller_is_better + self.global_batch_size = global_batch_size + self.profiling_enabled = profiling_enabled + + # Training loop variables + self.max_length = max_length + self.checkpoint_period = checkpoint_period + self.validation_period = validation_period + self.reporting_period = reporting_period + + # Training loop state + self.local_training = local_training + self.trial_id = 0 if local_training else self.core_context.train._trial_id @classmethod def pre_execute_hook( cls: Type["DeepSpeedTrialController"], - env: det.EnvContext, + trial_seed: int, distributed_backend: det._DistributedBackend, ) -> None: # We use an environment variable to allow users to enable custom initialization routine for @@ -87,18 +136,19 @@ def pre_execute_hook( # training batch. # TODO (Liam): seed data loading workers so that we can configure different seeds for # data augmentations per slot per worker. - random.seed(env.trial_seed) - np.random.seed(env.trial_seed) - torch.random.manual_seed(env.trial_seed) - - @classmethod - def from_trial( - cls: Type["DeepSpeedTrialController"], *args: Any, **kwargs: Any - ) -> det.TrialController: - return cls(*args, **kwargs) + random.seed(trial_seed) + np.random.seed(trial_seed) + torch.random.manual_seed(trial_seed) + + def _upload_tb_files(self) -> None: + self.context._maybe_reset_tbd_writer() + self.core_context.train.upload_tensorboard_files( + (lambda _: True) if self.is_chief else (lambda p: not p.match("*tfevents*")), + tensorboard.util.get_rank_aware_path, + ) def _set_data_loaders(self) -> None: - skip_batches = self.env.steps_completed + skip_batches = self.start_from_batch # Training and validation data loaders are not built for every slot when model parallelism # is used. @@ -144,14 +194,14 @@ def _set_data_loaders(self) -> None: ) if self.context.use_pipeline_parallel: - if len(self.validation_loader) < self.context.num_micro_batches_per_slot: + if len(self.validation_loader) < self.context.get_num_micro_batches_per_slot(): raise det.errors.InvalidExperimentException( "Number of train micro batches in validation data loader should not be " "less than the number of gradient accumulation steps when using " "pipeline parallelism." 
) excluded_micro_batches = ( - len(validation_data) % self.context.num_micro_batches_per_slot + len(validation_data) % self.context.get_num_micro_batches_per_slot() ) if excluded_micro_batches: logger.warning( @@ -182,9 +232,9 @@ def _set_data_loaders(self) -> None: if self.context.use_pipeline_parallel: self.num_validation_batches = ( - self.num_validation_batches // self.context.num_micro_batches_per_slot + self.num_validation_batches // self.context.get_num_micro_batches_per_slot() ) - self.validation_batch_size *= self.context.num_micro_batches_per_slot + self.validation_batch_size *= self.context.get_num_micro_batches_per_slot() # We will do a gather on to get train and val loader lengths and broadcast to all slots. self.context._epoch_len = ( @@ -192,28 +242,34 @@ def _set_data_loaders(self) -> None: ) all_epoch_lens = self.context.distributed.gather(self.context._epoch_len) if self.is_chief: - all_epoch_lens = [le for le in all_epoch_lens if le is not None] + all_epoch_lens = [le for le in all_epoch_lens if le is not None] # type: ignore if min(all_epoch_lens) < max(all_epoch_lens): logger.warning( "Training data loader length inconsistent across ranks. " "Using the minimum for epoch length." ) - self.context._epoch_len = min(all_epoch_lens) // self.context.num_micro_batches_per_slot + self.context._epoch_len = ( + min(all_epoch_lens) // self.context.get_num_micro_batches_per_slot() + ) self.context._epoch_len = self.context.distributed.broadcast(self.context._epoch_len) all_tuples = self.context.distributed.gather( (self.num_validation_batches, self.validation_batch_size) ) if self.is_chief: - all_num_validation_batches, all_validation_batch_size = zip(*all_tuples) - all_num_validation_batches = [le for le in all_num_validation_batches if le is not None] + all_num_validation_batches, all_validation_batch_size = zip(*all_tuples) # type: ignore + all_num_validation_batches = [ + le for le in all_num_validation_batches if le is not None + ] # type: ignore if min(all_num_validation_batches) < max(all_num_validation_batches): logger.warning( "Validation data loader length inconsistent across ranks. " "Using the minimum for validation length." ) self.num_validation_batches = min(all_num_validation_batches) - all_validation_batch_size = [le for le in all_validation_batch_size if le is not None] + all_validation_batch_size = [ + le for le in all_validation_batch_size if le is not None + ] # type: ignore if min(all_validation_batch_size) < max(all_validation_batch_size): logger.warning( "Validation batch size inconsistent across ranks. " @@ -244,7 +300,7 @@ def on_shutdown(callback_name: str, on_trial_shutdown: Callable) -> None: with contextlib.ExitStack() as exit_stack: for callback in self.callbacks.values(): - callback.on_trial_startup(self.steps_completed, self.env.latest_checkpoint) + callback.on_trial_startup(self.start_from_batch, self.latest_checkpoint) exit_stack.enter_context( defer(on_shutdown, callback.__class__.__name__, callback.on_trial_shutdown) ) @@ -264,19 +320,22 @@ def on_shutdown(callback_name: str, on_trial_shutdown: Callable) -> None: ) def cleanup_iterator() -> None: - # Explicitly trigger the training iterator's shutdown (which happens in __del__). + # Explicitly trigger the iterator's shutdown (which happens in __del__). # See the rather long note in pytorch/torch/utils/data/dataloader.py. del self.training_iterator exit_stack.enter_context(defer(cleanup_iterator)) # If a load path is provided load weights and restore the data location. 
- if self.env.latest_checkpoint is not None: - logger.info(f"Restoring trial from checkpoint {self.env.latest_checkpoint}") + if self.latest_checkpoint is not None: + logger.info(f"Restoring trial from checkpoint {self.latest_checkpoint}") with self.context._core.checkpoint.restore_path( - self.env.latest_checkpoint + self.latest_checkpoint ) as load_path: self._load(load_path) + else: + # If we are not loading, initialize a fresh state. + self.state = pytorch._TrialState(trial_id=self.trial_id) for callback in self.callbacks.values(): callback.on_training_start() @@ -288,172 +347,247 @@ def cleanup_iterator() -> None: self._run() def _run(self) -> None: - assert self.workloads is not None - for w, response_func in self.workloads: - try: - if w.kind == workload.Workload.Kind.RUN_STEP: - action = "training" - metrics = self._train_for_step( - w.step_id, - w.num_batches, - w.total_batches_processed, - ) - response = { - "metrics": metrics, - "stop_requested": self.context.get_stop_requested(), - } # type: workload.Response - metrics = self.context.distributed.broadcast(metrics) - for callback in self.callbacks.values(): - callback.on_training_workload_end( - avg_metrics=metrics["avg_metrics"], - batch_metrics=metrics["batch_metrics"], - ) - elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS: - action = "validation" - response = { - "metrics": self._compute_validation_metrics(), - "stop_requested": self.context.get_stop_requested(), - } - elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL: - action = "checkpointing" - metadata = { - "steps_completed": self.steps_completed, - "framework": f"torch-{torch.__version__}", - "format": "pickle", - } - with self.context._core.checkpoint.store_path(metadata, shard=True) as ( - path, - storage_id, - ): - self._save(path) - response = {"uuid": storage_id} - for callback in self.callbacks.values(): - callback.on_checkpoint_upload_end(uuid=storage_id) - else: - raise AssertionError("Unexpected workload: {}".format(w.kind)) + assert self.state + + try: + if ( + self.step_zero_validation + and self.val_from_previous_run is None + and self.state.batches_trained == 0 + ): + self._validate() + + self._train( + length=pytorch.Batch(1) if self.test_mode else self.max_length, + train_boundaries=[ + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.TRAIN, + unit=self.max_length, + ), + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.VALIDATE, unit=self.validation_period + ), + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.CHECKPOINT, + unit=self.checkpoint_period, + ), + # Scheduling unit is always configured in batches + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.REPORT, unit=self.reporting_period + ), + ], + ) + except pytorch._ShouldExit as e: + # Checkpoint unsaved work and exit. 
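Editor's note: the `_run()` rewrite above replaces the workload stream with a list of train boundaries checked during training. The sketch below shows, in stripped-down form, how per-period limits of this kind can be checked once per batch. It deliberately uses a plain dataclass rather than the real `pytorch._TrainBoundary`/`TrainUnit` machinery, so the names and fields here are stand-ins, not the library's API.

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class Boundary:
    """Simplified stand-in for a train boundary checked once per batch."""

    kind: str  # "train" | "validate" | "checkpoint" | "report"
    every: int  # period, in batches
    limit_reached: bool = False


def check_boundaries(boundaries: List[Boundary], batches_trained: int) -> List[str]:
    """Mark and return the boundary kinds whose period divides the batch count."""
    hit = []
    for b in boundaries:
        if batches_trained % b.every == 0:
            b.limit_reached = True
            hit.append(b.kind)
    return hit


boundaries = [
    Boundary("validate", every=500),
    Boundary("checkpoint", every=1000),
    Boundary("report", every=100),
]
print(check_boundaries(boundaries, batches_trained=1000))  # ['validate', 'checkpoint', 'report']
```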
+ if not e.skip_exit_checkpoint and not self._checkpoint_is_current(): + self._checkpoint(already_exiting=True) + + except det.InvalidHP as e: + # Catch InvalidHP to checkpoint before exiting and re-raise for cleanup by core.init() + if not self._checkpoint_is_current(): + self._checkpoint(already_exiting=True) + raise e - except det.InvalidHP as e: - logger.info(f"Invalid hyperparameter exception during {action}: {e}") - response = workload.InvalidHP() - response_func(response) - self.context._maybe_reset_tbd_writer() - self.upload_tb_files() + return - def get_epoch_idx(self, batch_id: int) -> int: + def _get_epoch_idx(self, batch_id: int) -> int: return batch_id // cast(int, self.context._epoch_len) - def _train_for_step( - self, step_id: int, num_batches: int, total_batches_processed: int - ) -> workload.Metrics: - """ - DeepSpeed allows specifying train_batch_size, train_micro_batch_size_per_gpu, and - gradient_accumulation_steps. The three are related as follows: - train_batch_size = train_micro_batch_size * gradient_accumulation_steps. - Hence, if two are specified, the third can be inferred. - - For pipeline parallel training, DeepSpeed will automatically interleave - gradient_accumulation_steps worth of micro batches in one train_batch/eval_batch call. - - With the default DeepSpeed model engine (no pipeline parallel training), the backward - and optimizer step calls track micro batches and will automatically update model weights - and lr scheduler if micro batches % gradient_accumulation_steps == 0. - - Comparing training with and without pipeline parallel is a common goal. Since DeepSpeed's - PipelineEngine trains on a number of micro batches equal to gradient accumulation steps, - we automatically perform gradient accumulation by default when pipeline parallelism is not - enabled. This makes it fair to compare training with and without pipeline parallelism - at a given batch idx. This can be turned off by setting - context.disable_auto_grad_accumulation. - """ - assert step_id > 0, "step_id should be greater than 0" - step_start_time = time.time() - self.context.reset_reducers() + def _train( + self, length: pytorch.TrainUnit, train_boundaries: List[pytorch._TrainBoundary] + ) -> None: + while self._steps_until_complete(length) > 0: + train_boundaries, training_metrics = self._train_with_boundaries(train_boundaries) + + metrics = self._aggregate_training_metrics(training_metrics) + metrics = self.context.distributed.broadcast(metrics) + for callback in self.callbacks.values(): + callback.on_training_workload_end( + avg_metrics=metrics["avg_metrics"], + batch_metrics=metrics["batch_metrics"], + ) + + step_reported = False + + for train_boundary in train_boundaries: + if not train_boundary.limit_reached: + continue + + # Train step limits reached, proceed accordingly. 
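Editor's note: the docstring removed in this hunk documents the relationship between DeepSpeed's batch-size settings. A quick numeric sketch (all values hypothetical) makes the inference concrete; note that with data parallelism DeepSpeed's effective `train_batch_size` is additionally multiplied by the data-parallel world size.

```python
# Hypothetical values: any two of the three settings determine the third.
train_micro_batch_size_per_gpu = 8
gradient_accumulation_steps = 8
data_parallel_world_size = 2

per_slot_batch = train_micro_batch_size_per_gpu * gradient_accumulation_steps  # 64
effective_train_batch_size = per_slot_batch * data_parallel_world_size  # 128
print(per_slot_batch, effective_train_batch_size)
```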
+ if train_boundary.step_type == pytorch._TrainBoundaryType.TRAIN: + if self.is_chief and not step_reported: + self._report_training_progress() + elif train_boundary.step_type == pytorch._TrainBoundaryType.REPORT: + if self.is_chief and not step_reported: + self._report_training_progress() + elif train_boundary.step_type == pytorch._TrainBoundaryType.VALIDATE: + if not self._validation_is_current(): + self._validate() + elif train_boundary.step_type == pytorch._TrainBoundaryType.CHECKPOINT: + if not self._checkpoint_is_current(): + self._checkpoint(already_exiting=False) + + # Reset train step limit + train_boundary.limit_reached = False + + # After checkpoint/validation steps, check preemption and upload to tensorboard + if self.context.get_enable_tensorboard_logging(): + self._upload_tb_files() + self._stop_requested() + + # Finished training. Perform final checkpoint/validation if necessary. + if not self._validation_is_current(): + self._validate() + if not self._checkpoint_is_current(): + self._checkpoint(already_exiting=False) + + def _train_with_boundaries( + self, train_boundaries: List[pytorch._TrainBoundary] + ) -> Tuple[List[pytorch._TrainBoundary], List]: + training_metrics = [] + + # Start of train step: tell core API and set model mode + if self.is_chief: + self.core_context.train.set_status("training") - # Set the behavior of certain layers (e.g., dropout) that are different - # between training and inference. for model in self.context.models: model.train() - start = total_batches_processed - end = start + num_batches + self.context.reset_reducers() + + epoch_len = self.context._epoch_len + assert epoch_len, "Training dataloader uninitialized." - per_batch_metrics = [] # type: List[Dict] - num_inputs = 0 + for batch_idx in range(epoch_len): + epoch_idx, batch_in_epoch_idx = divmod(batch_idx, epoch_len) - for batch_idx in range(start, end): - self.steps_completed += 1 - batch_start_time = time.time() + # Set the batch index on the trial context used by step_optimizer. self.context._current_batch_idx = batch_idx - if self.context.is_epoch_start(): - for callback in self.callbacks.values(): - callback.on_training_epoch_start(self.get_epoch_idx(batch_idx)) - # This can be inaccurate if the user's data loader does not return batches with - # the micro batch size. It is also slightly inaccurate if the data loader can return - # partial batches. The same sort of assumptions is made in the DeepSpeed - # model engine's accounting and profiling computations. - batch_inputs = ( - self.context.train_micro_batch_size_per_gpu - * self.context.num_micro_batches_per_slot - ) - num_inputs += batch_inputs - num_train_batch_calls = self.context.num_micro_batches_per_slot - if self.context.use_pipeline_parallel or self.context._manual_grad_accumulation: - num_train_batch_calls = 1 - self.context._loss_ids = {} - for _ in range(num_train_batch_calls): - with contextlib.ExitStack() as exit_stack: - if self.context.profiler: - exit_stack.enter_context(self.context.profiler) - - tr_metrics = self.trial.train_batch( - self.training_iterator, - self.get_epoch_idx(batch_idx), - batch_idx, - ) - if self.context.profiler: - self.context.profiler.step() + # Call epoch start callbacks before training first batch in epoch. 
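Editor's note: the `divmod` bookkeeping used in this loop (and again in `_step_batch`) is easiest to see with concrete numbers; the values below are arbitrary.

```python
epoch_len = 250  # batches per epoch, as reported by the training data loader
batch_idx = 612  # zero-based global batch index

epoch_idx, batch_in_epoch_idx = divmod(batch_idx, epoch_len)
print(epoch_idx, batch_in_epoch_idx)  # 2 112 -> the 113th batch of the 3rd epoch
```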
+ if batch_in_epoch_idx == 0: + self._on_epoch_start(epoch_idx) - if self.context._mpu.should_report_metrics: - if isinstance(tr_metrics, torch.Tensor): - tr_metrics = {"loss": tr_metrics} - if not isinstance(tr_metrics, dict): - raise det.errors.InvalidExperimentException( - "train_batch must return a dictionary " - f"mapping string names to Tensor metrics, got {type(tr_metrics)}", - ) + batch_metrics = self._train_batch(batch_idx=batch_idx, epoch_idx=epoch_idx) + training_metrics.extend(batch_metrics) + self._step_batch() - for name, metric in tr_metrics.items(): - # Convert PyTorch metric values to NumPy, so that - # `det.util.encode_json` handles them properly without - # needing a dependency on PyTorch. - if isinstance(metric, torch.Tensor): - metric = metric.cpu().detach().numpy() - tr_metrics[name] = metric - per_batch_metrics.append(tr_metrics) - # We do a check here to make sure that we do indeed process `num_micro_batches_per_slot` - # micro batches when training a batch for models that do not use pipeline parallelism. - model0 = self.context.models[0] - if not isinstance(model0, deepspeed.PipelineEngine): - assert ( - model0.micro_steps % self.context.num_micro_batches_per_slot == 0 - ), "did not train for gradient accumulation steps" - - batch_dur = time.time() - batch_start_time - samples_per_second = batch_inputs / batch_dur - samples_per_second *= self.context._mpu.data_parallel_world_size - - if self.context.is_epoch_end(): - for callback in self.callbacks.values(): - callback.on_training_epoch_end(self.get_epoch_idx(batch_idx)) + # Batch complete: check if any training periods have been reached and exit if any + for step in train_boundaries: + if isinstance(step.unit, pytorch.Batch): + if step.unit.should_stop(batch_idx + 1): + step.limit_reached = True + + # True epoch based training not supported, detect last batch of epoch to calculate + # fully-trained epochs + if isinstance(step.unit, pytorch.Epoch): + if step.unit.should_stop(epoch_idx + 1): + if batch_in_epoch_idx == epoch_len - 1: + step.limit_reached = True + + # Break early after one batch for test mode + if step.step_type == pytorch._TrainBoundaryType.TRAIN and self.test_mode: + step.limit_reached = True + + # Exit if any train step limits have been reached + if any(step.limit_reached for step in train_boundaries): + return train_boundaries, training_metrics + + # True epoch end + return train_boundaries, training_metrics + + def _train_batch(self, epoch_idx: int, batch_idx: int) -> List[dict]: + num_micro_batches = self.context.get_num_micro_batches_per_slot() + if self.context.use_pipeline_parallel or self.context._manual_grad_accumulation: + num_micro_batches = 1 + + # Reset loss IDs for AMP + self.context._loss_ids = {} + + batch_start_time = time.time() + per_batch_metrics = [] # type: List[Dict] + + for _ in range(num_micro_batches): + with contextlib.ExitStack() as exit_stack: + if self.context.profiler: + exit_stack.enter_context(self.context.profiler) + + training_metrics = self.trial.train_batch( + self.training_iterator, + epoch_idx, + batch_idx, + ) + + if self.context.profiler: + self.context.profiler.step() + + if self.context._mpu.should_report_metrics: + if isinstance(training_metrics, torch.Tensor): + training_metrics = {"loss": training_metrics} + if not isinstance(training_metrics, dict): + raise det.errors.InvalidExperimentException( + "train_batch must return a dictionary " + f"mapping string names to Tensor metrics, got {type(training_metrics)}", + ) + + for name, metric in 
training_metrics.items(): + # Convert PyTorch metric values to NumPy, so that + # `det.util.encode_json` handles them properly without + # needing a dependency on PyTorch. + if isinstance(metric, torch.Tensor): + metric = metric.cpu().detach().numpy() + training_metrics[name] = metric + per_batch_metrics.append(training_metrics) + # We do a check here to make sure that we do indeed process `num_micro_batches_per_slot` + # micro batches when training a batch for models that do not use pipeline parallelism. + model0 = self.context.models[0] + if not isinstance(model0, deepspeed.PipelineEngine): + assert ( + model0.micro_steps % self.context.get_num_micro_batches_per_slot() == 0 + ), "did not train for gradient accumulation steps" + + batch_dur = time.time() - batch_start_time + batch_inputs = ( + self.context.get_train_micro_batch_size_per_gpu() + * self.context.get_num_micro_batches_per_slot() + ) + samples_per_second = batch_inputs / batch_dur + samples_per_second *= self.context.distributed.size # Aggregate and reduce training metrics from all the training processes. - if self.context.distributed.size > 1 and self.context._average_training_metrics: - per_batch_metrics = pytorch._combine_and_average_training_metrics( + if self.context.distributed.size > 1: + metrics = pytorch._combine_and_average_training_metrics( self.context.distributed, per_batch_metrics ) - num_inputs *= self.context._mpu.data_parallel_world_size - metrics = det.util.make_metrics(num_inputs, per_batch_metrics) + else: + metrics = per_batch_metrics + + return metrics + + def _step_batch(self) -> None: + assert self.state + self.state.batches_trained += 1 + + epoch_len = self.context._epoch_len + assert epoch_len, "Training dataloader not initialized." + + # True epoch-based training is not supported. Epoch end is calculated with batch. + epoch_idx, batch_in_epoch_idx = divmod(self.state.batches_trained - 1, epoch_len) + + if batch_in_epoch_idx == epoch_len - 1: + self._on_epoch_end(epoch_idx) + self.state.epochs_trained += 1 + + def _aggregate_training_metrics(self, training_metrics: List[Dict]) -> Dict: + # Aggregate and reduce training metrics from all the training processes. + if self.context.distributed.size > 1: + batch_metrics = pytorch._combine_and_average_training_metrics( + self.context.distributed, training_metrics + ) + else: + batch_metrics = training_metrics + + metrics = det.util.make_metrics(None, batch_metrics) # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch # metrics are even logical for a custom reducer. 
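Editor's note: the averaging itself happens inside `pytorch._combine_and_average_training_metrics` and `det.util.make_metrics`, so here is a rough single-process sketch of what "aggregate per-batch metric dictionaries into avg_metrics" means. It assumes every batch reports the same scalar keys and is a simplification, not the library's actual reduction code.

```python
from typing import Dict, List

import numpy as np


def average_batch_metrics(batch_metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Average scalar metrics key-wise across batches (simplified stand-in)."""
    keys = batch_metrics[0].keys()
    return {k: float(np.mean([m[k] for m in batch_metrics])) for k in keys}


batch_metrics = [{"loss": 0.92}, {"loss": 0.87}, {"loss": 0.81}]
print(average_batch_metrics(batch_metrics))  # loss averages to ~0.867
```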
@@ -461,27 +595,127 @@ def _train_for_step( pytorch._convert_metrics_to_numpy(self.context.reduce_metrics(for_training=True)) ) - if self.is_chief: - step_duration = time.time() - step_start_time - logger.info(det.util.make_timing_log("trained", step_duration, num_inputs, num_batches)) - - if self.context.get_enable_tensorboard_logging(): - det.pytorch._log_tb_metrics( - self.context.get_tensorboard_writer(), - "train", - self.steps_completed, - metrics["avg_metrics"], - metrics["batch_metrics"], - ) - if not self.is_chief: return {} + # Only report on the chief worker + avg_metrics = metrics.get("avg_metrics", {}) + batch_metrics = metrics.get("batch_metrics", []) + + assert self.state + if self.context.get_enable_tensorboard_logging(): + pytorch._log_tb_metrics( + self.context.get_tensorboard_writer(), + "train", + self.state.batches_trained, + avg_metrics, + batch_metrics, + ) + + self.core_context.train.report_training_metrics( + steps_completed=self.state.batches_trained, + metrics=avg_metrics, + batch_metrics=batch_metrics, + ) return metrics + def _is_best_validation(self, now: float, before: Optional[float]) -> bool: + if before is None: + return True + + return (now < before) if self.smaller_is_better else (now > before) + + def _on_epoch_start(self, epoch_idx: int) -> None: + for callback in self.callbacks.values(): + sig = inspect.signature(callback.on_training_epoch_start) + if sig.parameters: + callback.on_training_epoch_start(epoch_idx) + else: + logger.warning( + "on_training_epoch_start() without parameters is deprecated" + " since 0.17.8. Please add epoch_idx parameter." + ) + callback.on_training_epoch_start() # type: ignore[call-arg] + + def _on_epoch_end(self, epoch_idx: int) -> None: + for callback in self.callbacks.values(): + callback.on_training_epoch_end(epoch_idx) + + def _checkpoint(self, already_exiting: bool) -> None: + if self.is_chief: + self.core_context.train.set_status("checkpointing") + + assert self.state + self.state.last_ckpt = self.state.batches_trained + try: + uuid = "" + metadata = { + "determined_version": det.__version__, + "steps_completed": self.state.batches_trained, + "framework": f"torch-{torch.__version__}", + "format": "pickle", + } + with self.context._core.checkpoint.store_path(metadata, shard=True) as ( + path, + storage_id, + ): + self._save(path) + uuid = storage_id + for callback in self.callbacks.values(): + callback.on_checkpoint_upload_end(uuid=uuid) + except det.InvalidHP: + if not already_exiting: + self.core_context.train.report_early_exit(core.EarlyExitReason.INVALID_HP) + raise pytorch._ShouldExit(skip_exit_checkpoint=True) + raise + + def _stop_requested(self) -> None: + if self.core_context.preempt.should_preempt(): + raise pytorch._ShouldExit() + if self.context.get_stop_requested(): + raise pytorch._ShouldExit() + + def _report_training_progress(self) -> None: + assert self.state + assert isinstance(self.max_length.value, int) + + if isinstance(self.max_length, pytorch.Batch): + progress = self.state.batches_trained / self.max_length.value + elif isinstance(self.max_length, pytorch.Epoch): + progress = self.state.epochs_trained / self.max_length.value + else: + raise ValueError(f"unexpected train unit type {type(self.max_length)}") + + self.core_context.train.report_progress(progress=progress) + + def _checkpoint_is_current(self) -> bool: + assert self.state + # State always persists checkpoint step in batches + return self.state.last_ckpt == self.state.batches_trained + + def _validation_is_current(self) -> bool: + 
assert self.state + # State persists validation step in batches + return self.state.last_val == self.state.batches_trained + + def _steps_until_complete(self, train_unit: pytorch.TrainUnit) -> int: + assert isinstance(train_unit.value, int), "invalid length type" + assert self.state + if isinstance(train_unit, pytorch.Batch): + return train_unit.value - self.state.batches_trained + elif isinstance(train_unit, pytorch.Epoch): + return train_unit.value - self.state.epochs_trained + else: + raise ValueError(f"Unrecognized train unit {train_unit}") + @torch.no_grad() - def _compute_validation_metrics(self) -> workload.Response: + def _validate(self) -> Dict[str, Any]: + # Report a validation step is starting. + if self.is_chief: + self.core_context.train.set_status("validating") + self.context.reset_reducers() + # Set the behavior of certain layers (e.g., dropout) that are # different between training and inference. for model in self.context.models: @@ -493,57 +727,83 @@ def _compute_validation_metrics(self) -> workload.Response: callback.on_validation_start() num_inputs = 0 - keys = None - batch_metrics = [] + metrics = {} # type: Dict[str, Any] - for callback in self.callbacks.values(): - callback.on_validation_epoch_start() - - validation_iterator = iter(self.validation_loader) if self.validation_loader else None - for idx in range(cast(int, self.num_validation_batches)): - num_inputs += cast(int, self.validation_batch_size) - # Note that when using pipeline parallelism, each call to evaluate_batch will request - # self.context.num_micro_batches_per_slot batches from the validation iterator. - # This is why we set self.num_validation_batches differently for pipeline parallel - # and no pipeline parallel when building the data loaders. - vld_metrics = self.trial.evaluate_batch(validation_iterator, idx) - if self.context._mpu.should_report_metrics: - if not isinstance(vld_metrics, dict): - raise det.errors.InvalidExperimentException( - "evaluate_batch must return a dictionary of string names " - "to Tensor metrics", - ) - # Verify validation metric names are the same across batches. - if keys is None: - keys = vld_metrics.keys() + batches_evaluated = -1 + + if self._evaluate_batch_defined(): + keys = None + batch_metrics = [] + + for callback in self.callbacks.values(): + callback.on_validation_epoch_start() + + validation_iterator = iter(self.validation_loader) if self.validation_loader else None + for idx in range(cast(int, self.num_validation_batches)): + batches_evaluated += 1 + num_inputs += cast(int, self.validation_batch_size) + # Note that when using pipeline parallelism, each call to evaluate_batch will + # request self.context.num_micro_batches_per_slot batches from the validation + # iterator. This is why we set self.num_validation_batches differently for + # pipeline parallel and no pipeline parallel when building the data loaders. + if util.has_param(self.trial.evaluate_batch, "batch_idx", 2): + vld_metrics = self.trial.evaluate_batch(validation_iterator, idx) else: - if keys != vld_metrics.keys(): + vld_metrics = self.trial.evaluate_batch(validation_iterator) # type: ignore + if self.context._mpu.should_report_metrics: + if not isinstance(vld_metrics, dict): raise det.errors.InvalidExperimentException( - "Validation metric names must match across all batches of data.", + "evaluate_batch must return a dictionary " + f"mapping string names to Tensor metrics, got {type(vld_metrics)}", ) - # TODO: For performance perform -> cpu() only at the end of validation. 
- batch_metrics.append(pytorch._convert_metrics_to_numpy(vld_metrics)) - if self.env.test_mode: - break + for name, metric in vld_metrics.items(): + # Convert PyTorch metric values to NumPy, so that + # `det.util.encode_json` handles them properly without + # needing a dependency on PyTorch. + if isinstance(metric, torch.Tensor): + metric = metric.cpu().detach().numpy() + vld_metrics[name] = metric + # Verify validation metric names are the same across batches. + if keys is None: + keys = vld_metrics.keys() + else: + if keys != vld_metrics.keys(): + raise ValueError( + "Validation metric names must match across all batches of data: " + f"{keys} != {vld_metrics.keys()}.", + ) + batch_metrics.append(pytorch._convert_metrics_to_numpy(vld_metrics)) + if self.test_mode: + break - # keys and list(keys) does not satisfy all cases because it will return dict_keys type if - # keys is an empty dict. this will then break when passed to zmq_broadcast since it does - # not know how to serialize dict_keys type. - all_keys = self.context.distributed.gather(keys if keys is None else list(keys)) - if self.is_chief: - all_keys = [k for k in all_keys if k is not None] - keys = all_keys[0] - keys = self.context.distributed.broadcast(keys) + for callback in self.callbacks.values(): + callback.on_validation_epoch_end(batch_metrics) + + metrics = pytorch._reduce_metrics( + self.context.distributed, + batch_metrics=batch_metrics, + keys=keys, + metrics_reducers=pytorch._prepare_metrics_reducers( + self.trial.evaluation_reducer(), keys=keys + ), + ) - for callback in self.callbacks.values(): - callback.on_validation_epoch_end(batch_metrics) + # Gather a list of per-worker (num_inputs, num_batches) tuples. + input_counts = self.context.distributed.gather((num_inputs, batches_evaluated + 1)) + + else: + assert self._evaluate_full_dataset_defined(), "evaluate_full_dataset not defined." + if self.is_chief: + assert self.validation_loader is not None + metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader) + + if not isinstance(metrics, dict): + raise TypeError( + f"eval() must return a dictionary, got {type(metrics).__name__}." + ) + + metrics = pytorch._convert_metrics_to_numpy(metrics) - metrics = pytorch._reduce_metrics( - self.context.distributed, - batch_metrics=batch_metrics, - keys=keys, - metrics_reducers=pytorch._prepare_metrics_reducers(pytorch.Reducer.AVG, keys=keys), - ) metrics.update( pytorch._convert_metrics_to_numpy(self.context.reduce_metrics(for_training=False)) ) @@ -554,51 +814,119 @@ def _compute_validation_metrics(self) -> workload.Response: ): logger.debug( "Broadcasting metrics to all worker processes to execute a " - "validation step end callback" + "validation step end callback." ) metrics = self.context.distributed.broadcast(metrics) for callback in self.callbacks.values(): callback.on_validation_end(metrics) + assert self.state + self.state.last_val = self.state.batches_trained + + # Report metrics. if self.is_chief: - num_inputs *= self.context._mpu.data_parallel_world_size - step_duration = time.time() - step_start_time - logger.info( - det.util.make_timing_log( - "validated", step_duration, num_inputs, cast(int, self.num_validation_batches) + # Skip reporting timings if evaluate_full_dataset() was defined. This is far less + # common than evaluate_batch() and we can't know how the user processed their + # validation data. + if self._evaluate_batch_defined(): + # Reshape and sum. 
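Editor's note: the "reshape and sum" step that follows is a compact idiom; with made-up per-worker counts it behaves like this.

```python
# Per-worker (num_inputs, num_batches) tuples, as produced by the gather above.
input_counts = [(512, 16), (512, 16), (480, 15)]

inputs_total, batches_total = [sum(n) for n in zip(*input_counts)]
print(inputs_total, batches_total)  # 1504 47
```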
+ # TODO: remove the type directive once we upgrade to mypy >= 1.7.0 + inputs_total, batches_total = [sum(n) for n in zip(*input_counts)] # type: ignore + step_duration = time.time() - step_start_time + logger.info( + det.util.make_timing_log( + "validated", step_duration, inputs_total, batches_total + ) ) - ) - if self.context.get_enable_tensorboard_logging(): - det.pytorch._log_tb_metrics( - self.context.get_tensorboard_writer(), "val", self.steps_completed, metrics + pytorch._log_tb_metrics( + self.context.get_tensorboard_writer(), + "val", + self.state.batches_trained, + metrics, ) - if not self.is_chief: - return {} + # Get best validation before reporting metrics. + best_validation_before = self.core_context.train.get_experiment_best_validation() - return {"num_inputs": num_inputs, "validation_metrics": metrics} + # We report "batch" and "epoch" only if these keys are not already reported in user + # metrics. + metrics["batches"] = metrics.get("batches", self.state.batches_trained) + metrics["epochs"] = metrics.get("epochs", self.state.epochs_trained) - def on_validation_step_end(self, metrics: Dict[str, Any]) -> None: - if self.context.get_enable_tensorboard_logging(): - det.pytorch._log_tb_metrics( - self.context.get_tensorboard_writer(), "val", self.steps_completed, metrics + self.core_context.train.report_validation_metrics( + steps_completed=self.state.batches_trained, metrics=metrics ) + should_checkpoint = False + + # Checkpoint according to policy. + if self.is_chief: + if not self._checkpoint_is_current(): + if self.checkpoint_policy == "all": + should_checkpoint = True + elif self.checkpoint_policy == "best": + assert ( + self.searcher_metric_name + ), "checkpoint policy 'best' but searcher metric name not defined" + searcher_metric = self._check_searcher_metric(metrics) + assert searcher_metric is not None + + if self._is_best_validation(now=searcher_metric, before=best_validation_before): + should_checkpoint = True + should_checkpoint = self.context.distributed.broadcast(should_checkpoint) + if should_checkpoint: + self._checkpoint(already_exiting=False) + return metrics + + def _check_searcher_metric(self, val_metrics: Dict) -> Any: + if self.searcher_metric_name not in val_metrics: + raise RuntimeError( + f"Search method is configured to use metric '{self.searcher_metric_name}' but " + f"model definition returned validation metrics {list(val_metrics.keys())}. The " + f"metric used by the search method must be one of the validation " + "metrics returned by the model definition." + ) + + # Check that the searcher metric has a scalar value so that it can be compared for + # search purposes. Other metrics don't have to be scalars. + searcher_metric = val_metrics[self.searcher_metric_name] + if not util.is_numerical_scalar(searcher_metric): + raise RuntimeError( + f"Searcher validation metric '{self.searcher_metric_name}' returned " + f"a non-scalar value: {searcher_metric}." + ) + return searcher_metric + + def _evaluate_batch_defined(self) -> bool: + return util.is_overridden(self.trial.evaluate_batch, DeepSpeedTrial) + + def _evaluate_full_dataset_defined(self) -> bool: + return util.is_overridden(self.trial.evaluate_full_dataset, DeepSpeedTrial) def _load(self, load_path: pathlib.Path) -> None: # Right now we will load all checkpoint shards on each node regardless of which # checkpoints are needed. # TODO (Liam): revisit later to optimize sharded checkpoint loading. 
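Editor's note: the "best" checkpoint policy above boils down to a single comparison against the best validation value the master has seen so far. Below is a tiny sketch of that decision with invented metric values; the comparison rule mirrors `_is_best_validation` from this change but is copied out for illustration only.

```python
from typing import Optional


def is_best_validation(now: float, before: Optional[float], smaller_is_better: bool) -> bool:
    """Illustrative copy of the comparison used by checkpoint_policy='best'."""
    if before is None:  # first validation of the experiment always counts as "best"
        return True
    return (now < before) if smaller_is_better else (now > before)


# The searcher metric here is a loss, so smaller is better.
print(is_best_validation(now=0.31, before=0.35, smaller_is_better=True))  # True -> checkpoint
print(is_best_validation(now=0.41, before=0.35, smaller_is_better=True))  # False -> skip
```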
+ potential_paths = [ + ["state_dict.pth"], + ["determined", "state_dict.pth"], + ["pedl", "state_dict.pth"], + ["checkpoint.pt"], + [f"det_state_dict_rank{self.context.distributed.rank}.pth"], + ] # Load stateful things tracked by Determined on all slots. - ckpt_path = f"det_state_dict_rank{self.context.distributed.rank}.pth" - maybe_ckpt = load_path.joinpath(ckpt_path) + checkpoint: Optional[Dict[str, Any]] = None + for ckpt_path in potential_paths: + maybe_ckpt = load_path.joinpath(*ckpt_path) + if maybe_ckpt.exists(): + checkpoint = torch.load(str(maybe_ckpt), map_location="cpu") + break - if not maybe_ckpt.exists(): + if checkpoint is None or not isinstance(checkpoint, dict): return - checkpoint = torch.load(str(maybe_ckpt), map_location="cpu") if not isinstance(checkpoint, dict): raise det.errors.InvalidExperimentException( f"Expected checkpoint at {maybe_ckpt} to be a dict " @@ -646,27 +974,68 @@ def _load(self, load_path: pathlib.Path) -> None: "callback will be initialized from scratch" ) - # Load workload sequencer state. - wlsq_path = load_path.joinpath("workload_sequencer.pkl") - if self.wlsq is not None and wlsq_path.exists(): - with wlsq_path.open("rb") as f: - self.wlsq.load_state(pickle.load(f)) + save_path = load_path.joinpath("trial_state.pkl") + + if save_path.exists(): + with save_path.open("rb") as f: + self._load_state(pickle.load(f)) + else: + # Support legacy save states. + wlsq_path = load_path.joinpath("workload_sequencer.pkl") + if wlsq_path.exists(): + with wlsq_path.open("rb") as f: + self._load_wlsq_state(pickle.load(f)) + + def _load_state(self, state: Any) -> None: + # Load our state from the checkpoint if we are continuing training after a pause or restart. + # If the trial_id doesn't match our current trial id, we're continuing training a previous + # trial and should start from a fresh state. + if state.get("trial_id") != self.trial_id: + self.state = pytorch._TrialState(trial_id=self.trial_id) + return + + self.state = pytorch._TrialState(**state) + assert self.state + + # Detect the case where the final validation we made was against this exact checkpoint. In + # that case, the master will know about the validation, but it would not appear in the + # checkpoint state. If the validation was before the last checkpoint, the checkpoint state + # is already correct, while any validations after the last checkpoint aren't valid anymore + # and can be safely ignored. + if self.state.batches_trained == self.val_from_previous_run: + self.state.last_val = self.state.batches_trained + + def _load_wlsq_state(self, state: Any) -> None: + if state.get("trial_id") != self.trial_id: + self.state = pytorch._TrialState(trial_id=self.trial_id) + return + + self.state = pytorch._TrialState( + trial_id=state.get("trial_id"), + last_ckpt=state.get("last_ckpt"), + last_val=state.get("last_val"), + step_id=state.get("step_id"), + # steps_completed is a legacy field kept to support loading from older checkpoints. 
+ # checkpoints should only persist batches_trained and epochs_trained + batches_trained=state.get("steps_completed"), + epochs_trained=self._get_epoch_idx(state.get("steps_completed")), + ) + + assert self.state + if self.state.batches_trained == self.val_from_previous_run: + self.state.last_val = self.state.batches_trained def _save(self, path: pathlib.Path) -> None: - if self.context.distributed.local_rank == 0: - path.mkdir(parents=True, exist_ok=True) - _ = self.context.distributed.gather_local(None) # sync + path.mkdir(parents=True, exist_ok=True) if self.is_chief: # We assume these stateful objects should be the same across slots and only have # the chief save them. - util.write_user_code(path, self.env.on_cluster) - - if self.wlsq is not None: - with path.joinpath("workload_sequencer.pkl").open("wb") as f: - pickle.dump(self.wlsq.get_state(), f) + util.write_user_code(path, not self.local_training) + assert self.state + with path.joinpath("trial_state.pkl").open("wb") as f: + pickle.dump(vars(self.state), f) - # Save per rank Determined checkpoint. rng_state = { "cpu_rng_state": torch.random.get_rng_state(), "np_rng_state": np.random.get_state(), @@ -675,22 +1044,21 @@ def _save(self, path: pathlib.Path) -> None: if torch.cuda.device_count(): rng_state["gpu_rng_state"] = torch.cuda.get_rng_state( - self.context.distributed.get_local_rank() + self.context.distributed.local_rank ) - checkpoint = {"rng_state": rng_state} # PyTorch uses optimizer objects that take the model parameters to # optimize on construction, so we store and reload the `state_dict()` # of the model and optimizer explicitly (instead of dumping the entire # objects) to avoid breaking the connection between the model and the # optimizer. - checkpoint["callbacks"] = { - name: callback.state_dict() for name, callback in self.callbacks.items() + checkpoint = { + "callbacks": {name: callback.state_dict() for name, callback in self.callbacks.items()}, + "rng_state": rng_state, } for callback in self.callbacks.values(): callback.on_checkpoint_save_start(checkpoint) - ckpt_name = f"det_state_dict_rank{self.context.distributed.rank}.pth" torch.save(checkpoint, str(path.joinpath(ckpt_name))) @@ -698,6 +1066,22 @@ def _save(self, path: pathlib.Path) -> None: # the save method provided by DeepSpeed. self.trial.save(self.context, path) + with open(path.joinpath("load_data.json"), "w") as f2: + try: + exp_conf = self.context.get_experiment_config() # type: Optional[Dict[str, Any]] + hparams = self.context.get_hparams() # type: Optional[Dict[str, Any]] + except ValueError: + exp_conf = None + hparams = None + + load_data = { + "trial_type": "DeepSpeedTrial", + "experiment_config": exp_conf, + "hparams": hparams, + } + + json.dump(load_data, f2) + for callback in self.callbacks.values(): # TODO(DET-7912): remove on_checkpoint_end once it has been deprecated long enough. 
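Editor's note: for orientation, a checkpoint directory written by `_save()` can be inspected directly. The sketch below assumes a hypothetical local path and reads only the two Determined-side files shown above; the DeepSpeed engine files written by `trial.save()` vary by configuration and are not touched here.

```python
import json
import pathlib
import pickle

# Placeholder path: point this at a checkpoint directory downloaded from the master.
ckpt_dir = pathlib.Path("/tmp/example_checkpoint")

if ckpt_dir.exists():
    # Written by the chief: trial type, experiment config, and hyperparameters.
    with (ckpt_dir / "load_data.json").open() as f:
        load_data = json.load(f)
    print(load_data["trial_type"])  # "DeepSpeedTrial"

    # Also written by the chief: the pickled vars() of the trial state.
    with (ckpt_dir / "trial_state.pkl").open("rb") as f:
        trial_state = pickle.load(f)
    print(trial_state["batches_trained"], trial_state["last_ckpt"])
```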
callback.on_checkpoint_end(str(path)) @@ -730,8 +1114,8 @@ class DeepSpeedTrial(det.LegacyTrial): """ - trial_controller_class = DeepSpeedTrialController - trial_context_class = det_ds.DeepSpeedTrialContext + trial_controller_class = DeepSpeedTrialController # type: ignore + trial_context_class = det_ds.DeepSpeedTrialContext # type: ignore @abc.abstractmethod def __init__(self, context: det_ds.DeepSpeedTrialContext) -> None: @@ -886,6 +1270,32 @@ def evaluate_batch( """ pass + def evaluate_full_dataset(self, data_loader: torch.utils.data.DataLoader) -> Dict[str, Any]: + """ + Calculate validation metrics on the entire validation dataset and + return them as a dictionary mapping metric names to reduced metric + values (i.e., each returned metric is the average or sum of that metric + across the entire validation set). + + This validation cannot be distributed and is performed on a single + device, even when multiple devices (slots) are used for training. Only + one of :meth:`evaluate_full_dataset` and :meth:`evaluate_batch` should + be overridden by a trial. + + The metrics returned from this function must be JSON-serializable. + + Arguments: + data_loader (torch.utils.data.DataLoader): data loader for evaluating. + """ + pass + + def evaluation_reducer(self) -> Union[pytorch.Reducer, Dict[str, pytorch.Reducer]]: + """ + Return a reducer for all evaluation metrics, or a dict mapping metric + names to individual reducers. Defaults to :obj:`determined.pytorch.Reducer.AVG`. + """ + return pytorch.Reducer.AVG + def save(self, context: det_ds.DeepSpeedTrialContext, path: pathlib.Path) -> None: """ Save is called on every GPU to make sure all checkpoint shards are saved. @@ -924,3 +1334,33 @@ def load( # DeepSpeed does not provide an error message with many assertion errors in the # checkpoint load module. raise AssertionError("Failed to load deepspeed checkpoint.") + + def get_batch_length(self, batch: Any) -> int: + """Count the number of records in a given batch. + + Override this method when you are using custom batch types, as produced + when iterating over the class:`determined.pytorch.DataLoader`. + For example, when using ``pytorch_geometric``: + + .. code-block:: python + + # Extra imports: + from determined.pytorch import DataLoader + from torch_geometric.data.dataloader import Collater + + # Trial methods: + def build_training_data_loader(self): + return DataLoader( + self.train_subset, + batch_size=self.context.get_per_slot_batch_size(), + collate_fn=Collater([], []), + ) + + def get_batch_length(self, batch): + # `batch` is `torch_geometric.data.batch.Batch`. + return batch.num_graphs + + Arguments: + batch (Any): input training or validation data batch object. 
+ """ + return pytorch.data_length(batch) diff --git a/harness/determined/pytorch/deepspeed/_trainer.py b/harness/determined/pytorch/deepspeed/_trainer.py new file mode 100644 index 00000000000..de2514dcf0f --- /dev/null +++ b/harness/determined/pytorch/deepspeed/_trainer.py @@ -0,0 +1,335 @@ +import contextlib +import logging +import os +import random +import sys +import warnings +from typing import Any, Dict, Iterator, Optional + +import deepspeed +import numpy as np +import torch + +import determined as det +from determined import core, gpu, pytorch +from determined.pytorch import deepspeed as det_ds + +logger = logging.getLogger("determined.pytorch.deepspeed") + + +class Trainer: + """ + ``pytorch.deepspeed.Trainer`` is an abstraction on top of a DeepSpeed training loop + that handles many training details under-the-hood, and exposes APIs for configuring + training-related features such as automatic checkpointing, validation, profiling, + metrics reporting, etc. + + ``Trainer`` must be initialized and called from within a + ``pytorch.deepspeed.DeepSpeedTrialContext``. + """ + + def __init__(self, trial: det_ds.DeepSpeedTrial, context: det_ds.DeepSpeedTrialContext): + self._trial = trial + self._context = context + self._core = self._context._core + self._info = det.get_cluster_info() + self._local_training = self._info is None or self._info.task_type != "TRIAL" + + def fit( + self, + checkpoint_period: Optional[pytorch.TrainUnit] = None, + validation_period: Optional[pytorch.TrainUnit] = None, + max_length: Optional[pytorch.TrainUnit] = None, + reporting_period: pytorch.TrainUnit = pytorch.Batch(100), # noqa: B008 + checkpoint_policy: str = "best", + latest_checkpoint: Optional[str] = None, + step_zero_validation: bool = False, + test_mode: bool = False, + profiling_enabled: bool = False, + ) -> None: + """ + ``fit()`` trains a ``DeepSpeedTrial`` configured from the ``Trainer`` and handles + checkpointing and validation steps, and metrics reporting. + + Arguments: + checkpoint_period: The number of steps to train for before checkpointing. This is + a ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or + instance of ``collections.abc.Container`` (list, tuple, etc.). For example, + ``Batch(100)`` would checkpoint every 100 batches, while ``Batch([5, 30, 45])`` + would checkpoint after every 5th, 30th, and 45th batch. + validation_period: The number of steps to train for before validating. This is a + ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or instance + of ``collections.abc.Container`` (list, tuple, etc.). For example, ``Batch(100)`` + would validate every 100 batches, while ``Batch([5, 30, 45])`` would validate + after every 5th, 30th, and 45th batch. + max_length: The maximum number of steps to train for. This is a ``TrainUnit`` type + (``Batch`` or ``Epoch``) which takes an ``int``. For example, ``Epoch(1)`` would + train for a maximum length of one epoch. + .. note:: + If using an ASHA searcher, this value should match the searcher config values in + the experiment config (i.e. ``Epoch(1)`` = `max_time: 1` and `time_metric: + "epochs"`). + + reporting_period: The number of steps to train for before reporting metrics and + searcher progress. For local training mode, metrics are printed to stdout. This + is a ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or + instance of ``collections.abc.Container`` (list, tuple, etc.). 
For example, + ``Batch(100)`` would report every 100 batches, while ``Batch([5, 30, 45])`` would + report after every 5th, 30th, and 45th batch. + checkpoint_policy: Controls how Determined performs checkpoints after validation + operations, if at all. Should be set to one of the following values: + + best (default): A checkpoint will be taken after every validation operation + that performs better than all previous validations for this experiment. + Validation metrics are compared according to the ``metric`` and + ``smaller_is_better`` fields in the searcher configuration. This option + is only supported for on-cluster training. + all: A checkpoint will be taken after every validation, no matter the + validation performance. + none: A checkpoint will never be taken due to a validation. However, + even with this policy selected, checkpoints are still expected to be taken + after the trial is finished training, due to cluster scheduling decisions, + before search method decisions, or due to ``min_checkpoint_period``. + latest_checkpoint: Configures the checkpoint used to start or continue training. + This value should be set to ``det.get_cluster_info().latest_checkpoint`` for + standard continue training functionality. + step_zero_validation: Configures whether to perform an initial validation before + training. Defaults to false. + test_mode: Runs a minimal loop of training for testing and debugging purposes. Will + train and validate one batch. Defaults to false. + profiling_enabled: Enables system metric profiling functionality for on-cluster + training. Defaults to false. + """ + # Set defaults. + if checkpoint_period is None: + checkpoint_period = pytorch.Batch(sys.maxsize) + + if validation_period is None: + validation_period = pytorch.Batch(sys.maxsize) + + if self._local_training: + if checkpoint_policy == "best": + logger.warning( + "checkpoint_policy='best' is not supported in local training mode. " + "Falling back to 'all'." + ) + checkpoint_policy = "all" + if max_length is None: + raise ValueError("max_length must be defined in local training mode.") + + if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance( + max_length.value, int + ): + raise TypeError( + "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) " + "type" + ) + + if profiling_enabled: + logger.warning("Profiling is not supported in local training mode.") + + smaller_is_better = True + searcher_metric_name = None + steps_completed = 0 + global_batch_size = None + else: + if test_mode: + raise ValueError("test_mode is only supported in local training mode.") + + assert self._info, "Unable to detect cluster info." + if latest_checkpoint is None and self._info.latest_checkpoint is not None: + logger.warning( + "latest_checkpoint has not been configured. Pause/resume training will not " + "be able to continue from latest checkpoint. Did you mean to set " + "`fit(latest_checkpoint=info.latest_checkpoint)'?" + ) + + smaller_is_better = bool(self._info.trial._config["searcher"]["smaller_is_better"]) + searcher_metric_name = self._info.trial._config["searcher"]["metric"] + steps_completed = int(self._info.trial._steps_completed) + global_batch_size = self._info.trial.hparams.get("global_batch_size", None) + if global_batch_size: + global_batch_size = int(global_batch_size) + + # Backwards compatibility: try to parse legacy `searcher.max_length` if `max_length` + # isn't passed in. 
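Editor's note: putting the `fit()` arguments documented above together, a training entrypoint typically looks like the sketch below, following the `det_ds.init()` / `Trainer` pattern used elsewhere in this change. `MyDeepSpeedTrial` is a hypothetical user-defined `DeepSpeedTrial` subclass, and the period values are examples only.

```python
import determined as det
from determined import pytorch
from determined.pytorch import deepspeed as det_ds


def run(trial_class) -> None:
    """Wire a DeepSpeedTrial subclass (passed in by the caller) to the Trainer."""
    info = det.get_cluster_info()
    latest_checkpoint = info.latest_checkpoint if info is not None else None

    with det_ds.init() as train_context:
        trial = trial_class(train_context)
        trainer = det_ds.Trainer(trial, train_context)
        trainer.fit(
            max_length=pytorch.Epoch(1),
            validation_period=pytorch.Batch(100),
            checkpoint_period=pytorch.Batch(100),
            latest_checkpoint=latest_checkpoint,
        )


# run(MyDeepSpeedTrial)  # MyDeepSpeedTrial: your det_ds.DeepSpeedTrial subclass
```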
+ if max_length is None: + max_length_val = core._parse_searcher_max_length(self._info.trial._config) + if max_length_val: + warnings.warn( + "Configuring `max_length` from the `searcher.max_length` experiment " + "config, which was deprecated in XXYYZZ and will be removed in a future " + "release. Please set `fit(max_length=X)` with your desired training length " + "directly.", + FutureWarning, + stacklevel=2, + ) + max_length_unit = core._parse_searcher_units(self._info.trial._config) + max_length = pytorch.TrainUnit._from_searcher_unit( + max_length_val, max_length_unit, global_batch_size + ) + + # If we couldn't parse the legacy `searcher.max_length`, raise an error. + if not max_length: + raise ValueError( + "`fit(max_length=X)` must be set with your desired training length." + ) + if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance( + max_length.value, int + ): + raise TypeError( + "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) " + "type." + ) + + _check_searcher_length(exp_conf=self._info.trial._config, max_length=max_length) + + trial_controller = det_ds.DeepSpeedTrialController( + trial_inst=self._trial, + context=self._context, + checkpoint_period=checkpoint_period, + validation_period=validation_period, + smaller_is_better=smaller_is_better, + steps_completed=steps_completed, + latest_checkpoint=latest_checkpoint, + local_training=self._local_training, + test_mode=test_mode, + reporting_period=reporting_period, + searcher_metric_name=searcher_metric_name, + checkpoint_policy=checkpoint_policy, + step_zero_validation=step_zero_validation, + max_length=max_length, + global_batch_size=global_batch_size, + profiling_enabled=profiling_enabled, + ) + + trial_controller.run() + + +def _check_searcher_length( + exp_conf: Dict[str, Any], + max_length: pytorch.TrainUnit, +) -> None: + """ + Certain searchers (ASHA and Adaptive ASHA) require configuring the maximum training length in + the experiment config. We check that the `max_length` passed to `fit()` matches the experiment + config and log warnings if it doesn't. + """ + time_metric = exp_conf["searcher"].get("time_metric") + if time_metric is not None: + max_time = exp_conf["searcher"].get("max_time") + assert max_time, "`searcher.max_time` not configured" + if time_metric == "batches": + if not isinstance(max_length, pytorch.Batch) or max_length.value != max_time: + logger.warning( + f"`max_length` passed into `fit()` method ({max_length}) does not match " + f"`searcher.max_time` and `searcher.time_metric` from the experiment config " + f"(Batch(value={max_time})). This may result in unexpected hyperparameter " + f"search behavior." + ) + elif time_metric == "epochs": + if not isinstance(max_length, pytorch.Epoch) or max_length.value != max_time: + logger.warning( + f"`max_length` passed into `fit()` method ({max_length}) does not match " + f"`searcher.max_time` and `searcher.time_metric` from the experiment config " + f"(Epoch(value={max_time})). This may result in unexpected hyperparameter " + f"search behavior." + ) + else: + logger.warning( + "`searcher.time_metric` must be either 'batches' or 'epochs' " + f"for training with PyTorchTrials, but got {time_metric}. " + f"Training will proceed with {max_length} but may result in unexpected behavior." 
+ ) + + +def _initialize_distributed_backend() -> Optional[core.DistributedContext]: + info = det.get_cluster_info() + distributed_backend = det._DistributedBackend() + + if distributed_backend.use_deepspeed(): + # We use an environment variable to allow users to enable custom initialization routine for + # distributed training since the pre_execute_hook runs before trial initialization. + manual_dist_init = os.environ.get("DET_MANUAL_INIT_DISTRIBUTED") + if not manual_dist_init: + deepspeed.init_distributed(auto_mpi_discovery=False) + return core.DistributedContext.from_deepspeed() + elif info and (len(info.container_addrs) > 1 or len(info.slot_ids) > 1): + raise ValueError( + "In multi-slot managed cluster training, you must wrap your training script with a " + "distributed launch layer such as determined.launch.deepspeed." + ) + return None + + +def _set_random_seeds(seed: int) -> None: + # Set identical random seeds on all training processes. + # When doing distributed training, each worker will start at a unique + # offset in the dataset, ensuring that it is processing a unique + # training batch. + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + +@contextlib.contextmanager +def init( + *, + hparams: Optional[Dict] = None, + exp_conf: Optional[Dict[str, Any]] = None, + distributed: Optional[core.DistributedContext] = None, + enable_tensorboard_logging: bool = True, +) -> Iterator[det_ds.DeepSpeedTrialContext]: + """ + Creates a DeepSpeedTrialContext for use with a DeepSpeedTrial. All trainer.* calls + must be within the scope of this context because there are resources started in + __enter__ that must be cleaned up in __exit__. + + Arguments: + hparams: (Optional) instance of hyperparameters for the trial + exp_conf: (Optional) for local-training mode. If unset, calling + context.get_experiment_config() will fail. + distributed: (Optional) custom distributed training configuration + enable_tensorboard_logging: Configures if upload to tensorboard is enabled + """ + cluster_info = det.get_cluster_info() + local_training = cluster_info is None or cluster_info.task_type != "TRIAL" + + # Pre-execute steps: initialize distributed backend and random seeds. + distributed_context = distributed + + if not local_training: + distributed_context = _initialize_distributed_backend() + + # Initialize default values. 
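Editor's note: to illustrate the `_check_searcher_length` helper defined a little earlier, it only logs warnings, and only when the ASHA-style `time_metric`/`max_time` fields are present. The snippet below calls it directly with a hypothetical experiment config; in normal use `fit()` invokes it for you, so this is purely illustrative and reaches into the private `_trainer` module.

```python
from determined import pytorch
from determined.pytorch.deepspeed import _trainer

exp_conf = {
    "searcher": {
        "name": "adaptive_asha",
        "metric": "loss",
        "smaller_is_better": True,
        "time_metric": "epochs",
        "max_time": 4,
    }
}

# Consistent: Epoch(4) matches time_metric="epochs" / max_time=4, so nothing is logged.
_trainer._check_searcher_length(exp_conf=exp_conf, max_length=pytorch.Epoch(4))

# Mismatched: logs a warning about unexpected hyperparameter search behavior.
_trainer._check_searcher_length(exp_conf=exp_conf, max_length=pytorch.Batch(1000))
```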
+ if local_training: + trial_seed = None + steps_completed = 0 + num_gpus = len(gpu.get_gpu_uuids()) + else: + assert cluster_info, "Unable to detect cluster info" + + trial_seed = cluster_info.trial.trial_seed + exp_conf = cluster_info.trial._config + steps_completed = cluster_info.trial._steps_completed + num_gpus = len(cluster_info.gpu_uuids) + + _set_random_seeds(trial_seed) + + with core.init( + distributed=distributed_context, + preempt_mode=core.PreemptMode.WorkersAskChief, + tensorboard_mode=core.TensorboardMode.MANUAL, + ) as core_context: + context = det_ds.DeepSpeedTrialContext( + core_context=core_context, + trial_seed=trial_seed, + hparams=hparams, + slots_per_trial=core_context.distributed.get_size(), + num_gpus=num_gpus, + exp_conf=exp_conf, + steps_completed=steps_completed, + enable_tensorboard_logging=enable_tensorboard_logging, + ) + + yield context diff --git a/harness/tests/experiment/fixtures/deepspeed_linear_model.py b/harness/tests/experiment/fixtures/deepspeed_linear_model.py index 900236c6cb1..3fd06f08cdf 100644 --- a/harness/tests/experiment/fixtures/deepspeed_linear_model.py +++ b/harness/tests/experiment/fixtures/deepspeed_linear_model.py @@ -12,6 +12,43 @@ from determined import pytorch +class MetricsCallbacks(pytorch.PyTorchCallback): + def __init__(self, trial) -> None: + self.trial = trial + super().__init__() + + def on_validation_end(self, metrics: Dict) -> None: + assert "loss" in metrics.keys() + + def on_checkpoint_upload_end(self, uuid: str) -> None: + self.trial.checkpoint_uuid = uuid + + def on_checkpoint_load_start(self, checkpoint: Optional[Dict]): + self.trial.checkpoint_found = checkpoint is not None + + +class ReproducibilityCallbacks(pytorch.PyTorchCallback): + def __init__(self, trial) -> None: + self.trial = trial + super().__init__() + + def on_validation_end(self, metrics: Dict) -> None: + self.trial.val_metrics.append(metrics) + + def on_training_workload_end(self, avg_metrics, batch_metrics): + self.trial.avg_metrics.append(avg_metrics) + self.trial.batch_metrics.append(batch_metrics) + + +class TwoEngineMetricsCallbacks(pytorch.PyTorchCallback): + def __init__(self) -> None: + super().__init__() + + def on_validation_end(self, metrics: Dict) -> None: + assert "loss1" in metrics.keys() + assert "loss2" in metrics.keys() + + class LinearDataset(torch.utils.data.Dataset): def __init__(self, a: int, b: int, num_samples: int): self.a = a @@ -31,9 +68,11 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: class LinearDeepSpeedTrial(det_ds.DeepSpeedTrial): _searcher_metric = "loss" - def __init__(self, context: det_ds.DeepSpeedTrialContext): + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): self.context = context - self.hparams = attrdict.AttrDict(context.get_hparams()) + self.hparams = attrdict.AttrDict(hparams) + self.checkpoint_uuid = None + self.checkpoint_found = None if ( self.hparams.test_manual_init_distributed or self.hparams.test_fail_manual_init_distributed @@ -64,6 +103,9 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext): if self.hparams.test_custom_reducer: self.reducer = self.context.wrap_reducer(lambda x: np.mean(x) * 2, name="loss_2x") + def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: + return {"my_callbacks": MetricsCallbacks(trial=self)} + def build_training_data_loader(self) -> Union[pytorch.DataLoader, torch.utils.data.DataLoader]: dataset = LinearDataset(1, 1, self.ds_config.train_batch_size * 2) dataloader = pytorch.DataLoader( @@ -158,8 +200,8 @@ def 
evaluate_batch( class LinearCallbackTrial(LinearDeepSpeedTrial): - def __init__(self, context: det_ds.DeepSpeedTrialContext): - super().__init__(context) + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): + super().__init__(context, hparams) self.counter = counter.Counter() def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: @@ -167,9 +209,9 @@ def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: class LinearTwoEngineTrial(LinearDeepSpeedTrial): - def __init__(self, context: det_ds.DeepSpeedTrialContext): + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): self.context = context - self.hparams = attrdict.AttrDict(context.get_hparams()) + self.hparams = attrdict.AttrDict(hparams) self.ds_config = attrdict.AttrDict(self.hparams.deepspeed_config) model1 = torch.nn.Linear(1, 1) model2 = torch.nn.Linear(1, 1) @@ -183,6 +225,9 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext): self.model1 = self.context.wrap_model_engine(self.model1) self.model2 = self.context.wrap_model_engine(self.model2) + def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: + return {"my_callbacks": TwoEngineMetricsCallbacks()} + def train_batch( self, dataloader_iter: Optional[Iterator[pytorch.TorchData]], @@ -214,10 +259,13 @@ def take_step(model): class LinearPipelineEngineTrial(LinearDeepSpeedTrial): - def __init__(self, context: det_ds.DeepSpeedTrialContext): + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): self.context = context - self.hparams = attrdict.AttrDict(context.get_hparams()) + self.hparams = attrdict.AttrDict(hparams) self.ds_config = attrdict.AttrDict(self.hparams.deepspeed_config) + self.avg_metrics = [] + self.batch_metrics = [] + self.val_metrics = [] model = torch.nn.Linear(1, 1) model = deepspeed.PipelineModule( layers=[model], @@ -232,6 +280,9 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext): self.model = self.context.wrap_model_engine(self.model) self.context.set_mpu(det_ds.make_deepspeed_mpu(self.model.mpu)) + def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: + return {"my_callbacks": ReproducibilityCallbacks(trial=self)} + def train_batch( self, dataloader_iter: Optional[Iterator[pytorch.TorchData]], diff --git a/harness/tests/experiment/integrations/test_deepspeed_trial.py b/harness/tests/experiment/integrations/test_deepspeed_trial.py index 06a2e4d6f57..15d5155ce08 100644 --- a/harness/tests/experiment/integrations/test_deepspeed_trial.py +++ b/harness/tests/experiment/integrations/test_deepspeed_trial.py @@ -4,17 +4,18 @@ import os import pathlib import shutil -from typing import Any, Dict, Iterator, Optional +from typing import Iterator +import appdirs import pytest import torch from deepspeed.runtime import config_utils import determined -import determined.pytorch.deepspeed as det_deepspeed -from determined import workload -from tests.experiment import utils # noqa: I100 -from tests.experiment.fixtures import deepspeed_linear_model +import determined.pytorch.deepspeed as det_ds +from determined import pytorch # noqa: I2041 +from determined.pytorch.deepspeed import _trainer # noqa: I2041 +from tests.experiment.fixtures import deepspeed_linear_model # noqa: I2041 ds_config_path = str( pathlib.Path(__file__).resolve().parent.parent.joinpath("fixtures/ds_config.json") @@ -82,521 +83,229 @@ def test_fail_manual_init_distributed(self, manual_init_distributed: None): updated_hparams = copy.deepcopy(self.hparams) 
updated_hparams["test_fail_manual_init_distributed"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - with pytest.raises(AssertionError, match=r"Distributed backend is not initialized. .*"): - _ = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_manual_init_distributed(self, manual_init_distributed: None): updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_manual_init_distributed"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) - _ = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) assert torch.distributed.is_initialized() def test_linear_model(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_manual_grad_acc_metrics(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_manual_grad_acc"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send(steps=10, validation_freq=10, train_batch_calls=1) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + 
trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_fail_manual_grad_acc_metrics(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_fail_manual_grad_acc"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send(steps=10, validation_freq=10, train_batch_calls=1) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - with pytest.raises(AssertionError, match="did not train for gradient accumulation steps"): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_custom_dataloader(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_manual_dataloader"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_fail_dataset_repro_check(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_fail_dataset_repro_check"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - with pytest.raises(RuntimeError, match=r".* reproducibility .* disable this check .*"): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_invalid_valid_dataset(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( 
determined.errors.InvalidExperimentException, match=r".* train micro batches .* should not be less than .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.InvalidValidDatasetTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.InvalidValidDatasetTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_invalid_train_metric(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( determined.errors.InvalidExperimentException, match=r"train_batch() must return a dictionary .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.InvalidTrainMetricTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.InvalidTrainMetricTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_invalid_valid_metric(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( determined.errors.InvalidExperimentException, match=r"evaluate_batch must return a dictionary .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.InvalidValidMetricTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.InvalidValidMetricTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_differing_valid_metric_keys(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( - determined.errors.InvalidExperimentException, - match=r".* metric names must match across all batches .*", + ValueError, + match=r"Validation metric names must match across all batches of data: .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.DifferingValidMetricKeyTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.DifferingValidMetricKeyTrial( + train_context, self.hparams + ) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_fail_multiple_set_mpu(self): - def make_workloads() -> workload.Stream: - trainer = 
utils.TrainAndValidate() - - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( - determined.errors.InvalidExperimentException, match=r"Only one MPU can be passed .*" + determined.errors.InvalidExperimentException, + match=r"Only one MPU can be passed to DeepSpeedTrialContext.", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.context.set_mpu( - det_deepspeed.make_data_parallel_mpu(controller.context.distributed) - ) - controller.context.set_mpu( - det_deepspeed.make_data_parallel_mpu(controller.context.distributed) - ) + with det_ds.init() as train_context: + _ = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + train_context.set_mpu(det_ds.make_data_parallel_mpu(train_context.distributed)) + train_context.set_mpu(det_ds.make_data_parallel_mpu(train_context.distributed)) def test_custom_reducer(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_custom_reducer"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_linear_non_scalar_metrics(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["return_non_scalar_metrics"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_linear_pipeline_model(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send(steps=1, validation_freq=1, train_batch_calls=1) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - 
trial_class=deepspeed_linear_model.LinearPipelineEngineTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearPipelineEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_two_model_engines(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss1" in metrics - assert "loss2" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearTwoEngineTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() - - @pytest.mark.skipif(not check_shm_size(), reason="insufficient shm size") - def test_checkpointing_and_restoring(self, tmp_path: pathlib.Path) -> None: - def make_trial_controller_fn( - workloads: workload.Stream, - checkpoint_dir: Optional[str] = None, - latest_checkpoint: Optional[Dict[str, Any]] = None, - steps_completed: int = 0, - ) -> determined.TrialController: - return utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearPipelineEngineTrial, - hparams=self.hparams, - workloads=workloads, - trial_seed=self.trial_seed, - checkpoint_dir=checkpoint_dir, - latest_checkpoint=latest_checkpoint, - steps_completed=steps_completed, - expose_gpus=True, - ) - - utils.checkpointing_and_restoring_test(make_trial_controller_fn, tmp_path) - - def test_restore_invalid_checkpoint(self, tmp_path: pathlib.Path) -> None: - # Build, train, and save a checkpoint with the normal hyperparameters. - checkpoint_dir = str(tmp_path.joinpath("checkpoint")) - latest_checkpoint = None - steps_completed = 0 - - def make_workloads_1() -> workload.Stream: - trainer = utils.TrainAndValidate() - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - interceptor = workload.WorkloadResponseInterceptor() - yield from interceptor.send(workload.checkpoint_workload()) - nonlocal latest_checkpoint, steps_completed - latest_checkpoint = interceptor.metrics_result()["uuid"] - steps_completed = trainer.get_steps_completed() - - controller1 = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=self.hparams, - workloads=make_workloads_1(), - trial_seed=self.trial_seed, - checkpoint_dir=checkpoint_dir, - expose_gpus=True, - ) - controller1.run() - - # Verify that an invalid architecture fails to load from the checkpoint. 
- def make_workloads_2() -> workload.Stream: - trainer = utils.TrainAndValidate() - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearTwoEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) + + def test_checkpointing_and_restoring(self) -> None: + with det_ds.init() as train_context: + trial1 = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial1, train_context) + assert trial1.checkpoint_uuid is None + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) + with det_ds.init() as train_context: + trial2 = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial2, train_context) + assert trial1.checkpoint_uuid is not None + trainer.fit( + validation_period=pytorch.Batch(16), + max_length=pytorch.Batch(16), + latest_checkpoint=os.path.join( + appdirs.user_data_dir("determined"), trial1.checkpoint_uuid + ), ) - with pytest.raises(AssertionError, match="Failed to load deepspeed checkpoint."): - controller2 = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearTwoEngineTrial, - hparams=self.hparams, - workloads=make_workloads_2(), - trial_seed=self.trial_seed, - checkpoint_dir=checkpoint_dir, - latest_checkpoint=latest_checkpoint, - steps_completed=steps_completed, - expose_gpus=True, - ) - controller2.run() - - @pytest.mark.skipif(not check_shm_size(), reason="insufficient shm size") + def test_restore_invalid_checkpoint(self) -> None: + with det_ds.init() as train_context: + trial1 = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial1, train_context) + assert trial1.checkpoint_uuid is None + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) + + with det_ds.init() as train_context: + trial2 = deepspeed_linear_model.LinearTwoEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial2, train_context) + assert trial1.checkpoint_uuid is not None + with pytest.raises(AssertionError, match="Failed to load deepspeed checkpoint."): + trainer.fit( + validation_period=pytorch.Batch(16), + max_length=pytorch.Batch(16), + latest_checkpoint=os.path.join( + appdirs.user_data_dir("determined"), trial1.checkpoint_uuid + ), + ) + + # TODO: Remove this particular skip after CI is updated (INFENG-659) + @pytest.mark.skipif(shutil.disk_usage("/dev/shm")[0] < 10**8, reason="insufficient shm size") def test_reproducibility(self) -> None: - def controller_fn(workloads: workload.Stream) -> determined.TrialController: - return utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearPipelineEngineTrial, - hparams=self.hparams, - workloads=workloads, - trial_seed=self.trial_seed, - expose_gpus=True, - ) - - utils.reproducibility_test(controller_fn, steps=1000, validation_freq=100) - - @pytest.mark.skipif(not check_shm_size(), reason="insufficient shm size") - def test_callbacks(self, tmp_path: pathlib.Path) -> None: - checkpoint_dir = tmp_path.joinpath("checkpoint") - latest_checkpoint = None - steps_completed = 0 - - controller = None - - def make_workloads1() -> workload.Stream: - nonlocal controller - assert 
controller.trial.counter.trial_startups == 1 - - yield workload.train_workload(1, 1, 0, 4), workload.ignore_workload_response - assert controller is not None, "controller was never set!" - assert controller.trial.counter.__dict__ == { - "trial_startups": 1, - "validation_steps_started": 0, - "validation_steps_ended": 0, - "checkpoints_written": 0, - "checkpoints_uploaded": 0, - "training_started_times": 1, - "training_epochs_started": 2, - "training_epochs_ended": 2, - "training_workloads_ended": 1, - "trial_shutdowns": 0, - } - - yield workload.validation_workload(), workload.ignore_workload_response - assert controller.trial.counter.__dict__ == { - "trial_startups": 1, - "validation_steps_started": 1, - "validation_steps_ended": 1, - "checkpoints_written": 0, - "checkpoints_uploaded": 0, - "training_started_times": 1, - "training_epochs_started": 2, - "training_epochs_ended": 2, - "training_workloads_ended": 1, - "trial_shutdowns": 0, - } - - interceptor = workload.WorkloadResponseInterceptor() - yield from interceptor.send(workload.checkpoint_workload()) - nonlocal latest_checkpoint, steps_completed - latest_checkpoint = interceptor.metrics_result()["uuid"] - steps_completed = 1 - assert controller.trial.counter.__dict__ == { + with det_ds.init() as train_context: + _trainer._set_random_seeds(self.trial_seed) + train_context._trial_seed = self.trial_seed + trial1 = deepspeed_linear_model.LinearPipelineEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial1, train_context) + trainer.fit(validation_period=pytorch.Batch(100), max_length=pytorch.Batch(1000)) + + with det_ds.init() as train_context: + _trainer._set_random_seeds(self.trial_seed) + train_context._trial_seed = self.trial_seed + trial2 = deepspeed_linear_model.LinearPipelineEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial2, train_context) + trainer.fit(validation_period=pytorch.Batch(100), max_length=pytorch.Batch(1000)) + + assert len(trial1.avg_metrics) == len(trial2.avg_metrics) + for A, B in zip(trial1.avg_metrics, trial2.avg_metrics): + assert A.keys() == B.keys() + for key in A.keys(): + assert abs(A[key] - B[key]) < 10e-7 + + assert len(trial1.batch_metrics) == len(trial2.batch_metrics) + for batch_idx in range(len(trial1.batch_metrics)): + for A, B in zip(trial1.batch_metrics[batch_idx], trial2.batch_metrics[batch_idx]): + assert A.keys() == B.keys() + for key in A.keys(): + assert abs(A[key] - B[key]) < 10e-7 + + assert len(trial1.val_metrics) == len(trial2.val_metrics) + for A, B in zip(trial1.val_metrics, trial2.val_metrics): + assert A.keys() == B.keys() + for key in A.keys(): + assert abs(A[key] - B[key]) < 10e-7 + + def test_callbacks(self) -> None: + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearCallbackTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Epoch(2)) + assert trial.counter.__dict__ == { "trial_startups": 1, "validation_steps_started": 1, "validation_steps_ended": 1, @@ -605,51 +314,10 @@ def make_workloads1() -> workload.Stream: "training_started_times": 1, "training_epochs_started": 2, "training_epochs_ended": 2, - "training_workloads_ended": 1, - "trial_shutdowns": 0, + "training_workloads_ended": 2, + "trial_shutdowns": 1, } - hparams1 = dict(self.hparams) - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearCallbackTrial, - hparams=hparams1, - workloads=make_workloads1(), - 
checkpoint_dir=str(checkpoint_dir), - expose_gpus=True, - ) - controller.run() - assert controller.trial.counter.trial_shutdowns == 1 - - # Verify the checkpoint loading callback works. - def make_workloads2() -> workload.Stream: - yield workload.train_workload(1, 1, 0, 2), workload.ignore_workload_response - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearCallbackTrial, - hparams=self.hparams, - workloads=make_workloads2(), - checkpoint_dir=str(checkpoint_dir), - latest_checkpoint=latest_checkpoint, - steps_completed=steps_completed, - expose_gpus=True, - ) - controller.run() - assert controller.trial.counter.__dict__ == { - # Note: trial_startups will get reset by the loading logic. - "trial_startups": 1, - "validation_steps_started": 1, - "validation_steps_ended": 1, - # Note: checkpoints_written, checkpoints_uploaded, and trial_shutdowns, cannot be - # persisted, as they are all updated after checkpointing. - "checkpoints_written": 0, - "checkpoints_uploaded": 0, - "training_started_times": 2, - "training_epochs_started": 3, - "training_epochs_ended": 3, - "training_workloads_ended": 2, - "trial_shutdowns": 1, - } - @pytest.mark.deepspeed def test_overwrite_deepspeed_config() -> None: @@ -661,16 +329,16 @@ def test_overwrite_deepspeed_config() -> None: expected_config = copy.deepcopy(deepspeed_config) expected_config["train_micro_batch_size_per_gpu"] = 2 expected_config["optimizer"]["params"]["lr"] = 0.001 - result = det_deepspeed.overwrite_deepspeed_config(base_ds_config, source_ds_config) + result = det_ds.overwrite_deepspeed_config(base_ds_config, source_ds_config) assert result == expected_config # Test load base deepspeed config from json file. base_ds_config = str( pathlib.Path(__file__).resolve().parent.parent.joinpath("fixtures/ds_config.json") ) - result = det_deepspeed.overwrite_deepspeed_config(base_ds_config, source_ds_config) + result = det_ds.overwrite_deepspeed_config(base_ds_config, source_ds_config) assert result == expected_config # Test fail invalid base_ds_config argument. with pytest.raises(TypeError, match="Expected string or dict for base_ds_config argument."): - _ = det_deepspeed.overwrite_deepspeed_config([1, 2], source_ds_config) + _ = det_ds.overwrite_deepspeed_config([1, 2], source_ds_config)
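
The updated tests above all exercise the same Trainer-API pattern that this patch migrates to: open a `det_ds.init()` context, construct the trial with an explicit hparams dict, wrap it in `det_ds.Trainer`, and call `fit()`. A minimal sketch of that pattern outside the test harness follows; `MyTrial` and `my_hparams` are illustrative placeholders (not part of this diff), and the batch counts simply mirror the values used in the tests.

```python
import determined.pytorch.deepspeed as det_ds
from determined import pytorch

# Placeholder hyperparameters (assumption, not from this diff): any DeepSpeedTrial
# subclass that accepts (context, hparams) works here; the fixtures above pass a
# dict that includes a "deepspeed_config" entry, among other test-specific flags.
my_hparams: dict = {}

with det_ds.init() as train_context:
    trial = MyTrial(train_context, my_hparams)  # MyTrial is a hypothetical DeepSpeedTrial subclass
    trainer = det_ds.Trainer(trial, train_context)
    # Validate every 16 batches and stop after 16 batches, as the tests above do.
    # To resume from an earlier checkpoint (see test_checkpointing_and_restoring),
    # also pass latest_checkpoint=<path to the saved checkpoint> to fit().
    trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16))
```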