[PyTorch] Minor optimizations to reduce CPU overheads in modules (#1191)
* CPU perf optimization in linear autograd function

Avoid enable_grad context when possible in cast function. Cache distributed group properties.

Signed-off-by: Tim Moon <[email protected]>

* CPU perf optimization in prepare_forward function

Avoid torch.nn.Module impl of __setattr__.

Signed-off-by: Tim Moon <[email protected]>

* Avoid module import in TE module forwards

Signed-off-by: Tim Moon <[email protected]>

* Use fast getter for params

Signed-off-by: Tim Moon <[email protected]>

* Reuse tensor dims in linear autograd func

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply optimizations to grouped linear

Signed-off-by: Tim Moon <[email protected]>

* Debug test failures

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug test failures

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Avoid deepcopy in tests

Signed-off-by: Tim Moon <[email protected]>

* Move _fast_setattr logic to __setattr__ method

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] authored Oct 4, 2024
1 parent 10cceae commit 9d976bc
Showing 13 changed files with 140 additions and 70 deletions.
10 changes: 8 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -1602,10 +1602,12 @@ def test_gpt_cuda_graph(dtype, bs, model):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

block = TransformerLayer(
block_args = (
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
@@ -1617,7 +1619,11 @@ def test_gpt_cuda_graph(dtype, bs, model):
output_layernorm=False,
device="cuda",
)
graphed_block = copy.deepcopy(block)
block = TransformerLayer(*block_args, **block_kwargs)
graphed_block = TransformerLayer(*block_args, **block_kwargs)
with torch.no_grad():
for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
param2.copy_(param1)

out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
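The test now builds the two TransformerLayer instances from the same args/kwargs and copies parameters, rather than deepcopying the module. A minimal sketch of that pattern with a stand-in torch.nn.Sequential model (the layer type and sizes here are illustrative, not the TE layer):

import torch

# Stand-in for the real TransformerLayer construction; building both modules
# from the same arguments mirrors the block_args/block_kwargs split above.
def make_block() -> torch.nn.Module:
    return torch.nn.Sequential(
        torch.nn.Linear(16, 64),
        torch.nn.GELU(),
        torch.nn.Linear(64, 16),
    )

block = make_block()
graphed_block = make_block()

# Copy parameters instead of deepcopying the whole module, which avoids
# cloning buffers, hooks, and any other per-module state.
with torch.no_grad():
    for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
        param2.copy_(param1)

x = torch.randn(4, 16)
assert torch.allclose(block(x), graphed_block(x))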
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/cpu_offload.py
@@ -16,6 +16,11 @@
CPUOffloadEnabled = False


def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled


class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
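Exposing the flag through is_cpu_offload_enabled() lets callers read the current value with a plain function call instead of re-importing CPUOffloadEnabled inside every forward; a from-import binds a copy of the value at import time, which is why the per-call import existed in the first place. A minimal sketch of the difference, using a hypothetical module built on the fly:

import types

# Hypothetical stand-in for transformer_engine.pytorch.cpu_offload: a module
# with a mutable flag and an accessor that reads it at call time.
offload_state = types.ModuleType("offload_state")
offload_state.ENABLED = False
offload_state.is_enabled = lambda: offload_state.ENABLED

# "from offload_state import ENABLED" would bind the *value* at import time:
enabled_copy = offload_state.ENABLED

offload_state.ENABLED = True  # e.g. an offload context manager flips the flag

print(enabled_copy)                # False: stale copy from import time
print(offload_state.is_enabled())  # True: the accessor reads the live flag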
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/distributed.py
@@ -6,6 +6,7 @@
from __future__ import annotations

from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings

@@ -125,13 +126,15 @@ def set_tensor_model_parallel_attributes(
setattr(tensor, "partition_stride", stride)


@lru_cache
def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
"""Return world size for the distributed group."""
if not torch.distributed.is_initialized():
return 1
return torch.distributed.get_world_size(group=group)


@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
"""Return my rank for the distributed group."""
assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
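Wrapping get_distributed_world_size and get_distributed_rank in functools.lru_cache is safe because a process group's size and rank are fixed after initialization, so the torch.distributed query only needs to run once per group object. A sketch of the same memoization pattern with a hypothetical group class standing in for torch.distributed:

from functools import lru_cache
from typing import Optional

class FakeGroup:
    """Hypothetical stand-in for a torch.distributed process group."""
    def __init__(self, size: int) -> None:
        self._size = size

calls = 0

@lru_cache
def get_world_size(group: Optional[FakeGroup] = None) -> int:
    """Memoized per group object, since the size never changes."""
    global calls
    calls += 1
    return 1 if group is None else group._size

group = FakeGroup(8)
assert get_world_size(group) == 8
assert get_world_size(group) == 8  # second call is served from the cache
assert calls == 1
# Note: lru_cache keys on the group argument itself, so the group must be
# hashable and passed as the same object across calls.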
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/jit.py
@@ -109,7 +109,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0:
if bias is not None and bias.numel() != 0:
return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp)

@@ -119,7 +119,7 @@ def bgrad_dgelu_fused(
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`"""
with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0:
if bias is not None and bias.numel() != 0:
return bgrad_dgelu_fused_(grad_output, inp, bias)
return None, dgelu_fused_(grad_output, inp)

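With the extra check, passing bias=None now falls through to the bias-free GELU path instead of failing on None.numel(). The same guard pattern, sketched with plain torch ops rather than the TE fused kernels:

from typing import Optional
import torch
import torch.nn.functional as F

def gelu_maybe_bias(inp: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Use the bias+GELU path only when a non-empty bias tensor is present."""
    if bias is not None and bias.numel() != 0:
        return F.gelu(inp + bias, approximate="tanh")
    return F.gelu(inp, approximate="tanh")

x = torch.randn(4, 8)
print(gelu_maybe_bias(x, None).shape)            # torch.Size([4, 8])
print(gelu_maybe_bias(x, torch.zeros(8)).shape)  # torch.Size([4, 8])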
41 changes: 35 additions & 6 deletions transformer_engine/pytorch/module/base.py
@@ -11,7 +11,7 @@
import fcntl
import struct
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager

import torch
@@ -406,6 +406,36 @@ def __init__(self) -> None:
self.fsdp_wrapped = False
self.fsdp_group = None
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
self.activation_dtype: Optional[torch.dtype] = None

# Fast getter for parameters
# Note: torch.nn.Module does not store parameters like normal
# attrs, but rather in a dict. When attempting to access, the
# module will raise an AttributeError in __getattribute__ and
# call a custom __getattr__. This is unnecessary overhead if
# we know we are accessing a parameter.
self._fast_get_param: Callable[str, torch.nn.Parameter]
self._fast_get_param = self.__dict__["_parameters"].get

# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
"activation_dtype",
"fp8",
"fp8_initialized",
"fp8_calibration",
"fp8_parameters",
}

def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)

def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
@@ -593,7 +623,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
return

# All checks after this have already been performed once, thus skip
if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
if self.activation_dtype == inp.dtype:
return

dtype = inp.dtype
@@ -708,10 +738,9 @@ def prepare_forward(
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
if not allow_non_contiguous:
yield inp.contiguous()
else:
yield inp
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp

if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
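Both base-class changes sidestep torch.nn.Module's attribute machinery: parameter reads go straight to the _parameters dict (normal access first misses in __getattribute__ and then falls back to the module's custom __getattr__), and a small whitelist of plain attributes skips nn.Module.__setattr__'s parameter/buffer/module bookkeeping. A minimal sketch of both patterns outside of Transformer Engine:

import torch

class FastAttrModule(torch.nn.Module):
    # Plain attrs allowed to bypass nn.Module.__setattr__ bookkeeping.
    _fast_setattr_names = {"activation_dtype", "fp8"}

    def __init__(self) -> None:
        super().__init__()
        self.weight0 = torch.nn.Parameter(torch.randn(8, 8))
        # Bound dict lookup: skips the __getattribute__ miss and __getattr__.
        self._fast_get_param = self.__dict__["_parameters"].get

    def __setattr__(self, name, value):
        if name in FastAttrModule._fast_setattr_names:
            self.__dict__[name] = value  # plain attribute, no isinstance checks
        else:
            super().__setattr__(name, value)

m = FastAttrModule()
m.activation_dtype = torch.float16            # fast path
assert m._fast_get_param("weight0") is m.weight0
assert m._fast_get_param("missing") is None   # dict.get, no AttributeError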
11 changes: 5 additions & 6 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -39,8 +39,9 @@
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..tensor import Float8Tensor, QuantizedTensor
from ..export import is_in_onnx_export_mode
from ..cpu_offload import is_cpu_offload_enabled

__all__ = ["GroupedLinear"]

@@ -715,11 +716,11 @@ def forward(

with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:

weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
weight_tensors = [
w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
]

weight_tensors_fp8 = [None] * self.num_gemms
@@ -746,8 +747,6 @@ def forward(
skip_update_flag=skip_fp8_weight_update,
)

from ..cpu_offload import CPUOffloadEnabled

if torch.is_grad_enabled():
linear_fn = _GroupedLinear.apply
args = []
@@ -763,7 +762,7 @@ def forward(
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
self._offsets,
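Besides reusing the fast parameter getter and is_cpu_offload_enabled(), the grouped path now dequantizes through the QuantizedTensor base class rather than special-casing Float8Tensor.from_float8(), so any quantized weight format is handled uniformly. A rough sketch of that dispatch with a hypothetical stub class:

import torch

class QuantizedTensorStub:
    """Hypothetical stand-in for the QuantizedTensor base class."""
    def __init__(self, data: torch.Tensor) -> None:
        self._data = data
    def dequantize(self) -> torch.Tensor:
        return self._data.float()

weights = [QuantizedTensorStub(torch.randn(4, 4).half()), torch.randn(4, 4)]
# Dispatch on the common base class so every quantized format takes the same
# path, instead of checking for one concrete subclass.
weights = [w.dequantize() if isinstance(w, QuantizedTensorStub) else w for w in weights]
print([w.dtype for w in weights])  # [torch.float32, torch.float32]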
17 changes: 15 additions & 2 deletions transformer_engine/pytorch/module/layernorm.py
@@ -12,7 +12,6 @@
from torch.nn import init

import transformer_engine_torch as tex
from .base import TransformerEngineBaseModule
from ..cpp_extensions import (
layernorm_fwd_inf,
)
@@ -143,6 +142,7 @@ def __init__(
)
)
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None

self.reset_parameters(defer_init=(device == "meta"))

@@ -186,8 +186,21 @@ def reset_parameters(self, defer_init=False) -> None:
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""

# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype

if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
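Inlining the dtype logic removes the base-module import from layernorm.py and, because activation_dtype is initialized up front, the per-parameter dtype check only reruns when the input dtype actually changes. A minimal sketch of that caching pattern in a plain module (the arithmetic is a placeholder, not TE's LayerNorm kernel):

import torch

class DtypeCachingLayer(torch.nn.Module):
    """Cache the activation dtype so the parameter consistency check
    only reruns when the input dtype changes between calls."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(8))
        self.activation_dtype = None

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
        elif self.activation_dtype != inp.dtype:
            # Only reached on the first call or when the dtype changes.
            for name, param in self.named_parameters():
                assert param.dtype == inp.dtype, (
                    f"dtype mismatch: input {inp.dtype}, {name} {param.dtype}"
                )
            self.activation_dtype = inp.dtype
        return inp * self.scale

layer = DtypeCachingLayer()
x = torch.randn(4, 8)
layer(x)  # first call walks the parameters
layer(x)  # later calls skip the check entirely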
26 changes: 12 additions & 14 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -46,6 +46,7 @@
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
from ..cpu_offload import is_cpu_offload_enabled

__all__ = ["LayerNormLinear"]

@@ -94,8 +95,9 @@ def forward(
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat)
@@ -339,7 +341,7 @@ def forward(
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.inp_shape = inp_shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
@@ -369,7 +371,7 @@ def forward(
out, _ = allreduce(out, tp_group)

# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
out = out.view(-1, *inp_shape[1:-1], out_features)

if return_layernorm_output:
if return_layernorm_output_gathered:
@@ -1149,7 +1151,7 @@ def forward(
with self.prepare_forward(inp, is_first_microbatch) as inp:

# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
@@ -1160,11 +1162,9 @@ def forward(
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
)
bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused

# Initialize FP8 weights if needed
weight_fp8 = None
@@ -1190,8 +1190,6 @@ def forward(
skip_update_flag=skip_fp8_weight_update,
)

from ..cpu_offload import CPUOffloadEnabled

if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
args = []
@@ -1200,8 +1198,8 @@ def forward(
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self._fast_get_param("layer_norm_weight"),
self._fast_get_param("layer_norm_bias"),
weight_tensor,
weight_fp8,
bias_tensor,
@@ -1212,7 +1210,7 @@ def forward(
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
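Reading weight.shape and inp.shape once into locals replaces the repeated ln_weight.numel(), inp.shape, and out.shape calls inside the autograd forward; each of those accesses crosses into the C++ binding, so caching the dims trims per-call Python overhead. A small sketch of the pattern (plain F.linear standing in for the TE GEMM):

import torch
import torch.nn.functional as F

def forward_sketch(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Read the tensor dims once and reuse them for checks and reshapes."""
    out_features, in_features = weight.shape
    inp_shape = inp.shape
    assert inp_shape[-1] == in_features, "GEMM not possible"

    out = F.linear(inp.view(-1, in_features), weight)
    # Reuse the cached dims instead of querying inp.shape / out.shape again.
    return out.view(*inp_shape[:-1], out_features)

x = torch.randn(2, 3, 16)
w = torch.randn(32, 16)
print(forward_sketch(x, w).shape)  # torch.Size([2, 3, 32])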