Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor (DTensor) and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism.
.. warning::
    Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your ``nn.Module`` using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel
.. currentmodule:: torch.distributed.tensor.parallel
.. autofunction:: parallelize_module
Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.SequenceParallel
  :members:
  :undoc-members:

To simply configure the ``nn.Module``'s inputs and outputs with DTensor layouts
and perform the necessary layout redistributions, without distributing the module
parameters to DTensors, the following ``ParallelStyle``\ s can be used in
the ``parallelize_plan`` when calling ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput
  :members:
  :undoc-members:

.. note:: When using ``Shard(dim)`` as the input/output layout for the above
  ``ParallelStyle``\ s, we assume the input/output activation tensors are evenly sharded on
  the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. For instance,
  since ``RowwiseParallel`` accepts input that is sharded on the last dimension, it assumes
  the input tensor has already been evenly sharded on the last dimension. For unevenly
  sharded activation tensors, one can pass a DTensor directly to the partitioned modules
  and use ``use_local_output=False`` to return a DTensor after each ``ParallelStyle``, so that
  the DTensor can track the uneven sharding information.

For models like Transformer, we recommend using ``ColwiseParallel``
and ``RowwiseParallel`` together in the ``parallelize_plan`` to achieve the desired
sharding for the entire model (i.e. Attention and MLP).
Parallelized cross-entropy loss computation (loss parallelism) is supported via the following context manager:

.. autofunction:: torch.distributed.tensor.parallel.loss_parallel

.. warning::
    The loss_parallel API is experimental and subject to change.
