diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index ad34b4996f..c0f45ada4e 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -1602,10 +1602,12 @@ def test_gpt_cuda_graph(dtype, bs, model):
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
 
-    block = TransformerLayer(
+    block_args = (
         config.hidden_size,
         4 * config.hidden_size,
         config.num_attention_heads,
+    )
+    block_kwargs = dict(
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
@@ -1617,7 +1619,11 @@ def test_gpt_cuda_graph(dtype, bs, model):
         output_layernorm=False,
         device="cuda",
     )
-    graphed_block = copy.deepcopy(block)
+    block = TransformerLayer(*block_args, **block_kwargs)
+    graphed_block = TransformerLayer(*block_args, **block_kwargs)
+    with torch.no_grad():
+        for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
+            param2.copy_(param1)
 
     out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
     graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py
index 4e9c74d396..5c32cf2103 100644
--- a/transformer_engine/pytorch/cpu_offload.py
+++ b/transformer_engine/pytorch/cpu_offload.py
@@ -16,6 +16,11 @@
 CPUOffloadEnabled = False
 
 
+def is_cpu_offload_enabled() -> bool:
+    """Check if CPU offloading is currently enabled."""
+    return CPUOffloadEnabled
+
+
 class CpuOffloadSavedTensorHook:
     """Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index e9fb11e3b9..0e27c64e3f 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -6,6 +6,7 @@
 from __future__ import annotations
 from contextlib import contextmanager, AbstractContextManager, ContextDecorator
+from functools import lru_cache
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
 
@@ -125,6 +126,7 @@ def set_tensor_model_parallel_attributes(
     setattr(tensor, "partition_stride", stride)
 
 
+@lru_cache
 def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
     """Return world size for the distributed group."""
     if not torch.distributed.is_initialized():
@@ -132,6 +134,7 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
         return 1
     return torch.distributed.get_world_size(group=group)
 
 
+@lru_cache
 def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
     """Return my rank for the distributed group."""
     assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
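Not part of the patch: a minimal sketch of the caching pattern that the `@lru_cache` decorators above rely on. `functools.lru_cache` memoizes the result per `group` argument (process-group handles hash by identity), which is safe only if a group's size and the caller's rank never change after initialization. The wrapper name below is illustrative, not a Transformer Engine API.

```python
import functools

import torch


@functools.lru_cache
def cached_world_size(group=None) -> int:
    """Memoized per-group world size; assumes group membership is static."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size(group=group)


# The first call queries torch.distributed; later calls with the same
# `group` object are served from the cache without touching the backend.
print(cached_world_size())              # 1 when torch.distributed is not initialized
print(cached_world_size())              # cache hit
print(cached_world_size.cache_info())   # hits=1, misses=1
```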
diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py
index 1646847162..3f642dd3cf 100644
--- a/transformer_engine/pytorch/jit.py
+++ b/transformer_engine/pytorch/jit.py
@@ -109,7 +109,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Disable native AMP for bias_gelu_fused_"""
     with torch.cuda.amp.autocast(enabled=False):
-        if bias.numel() != 0:
+        if bias is not None and bias.numel() != 0:
             return bias_gelu_fused_(inp, bias)
         return gelu_fused_(inp)
 
@@ -119,7 +119,7 @@ def bgrad_dgelu_fused(
 ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
     """Disable native AMP for `bgrad_dgelu_fused_`"""
     with torch.cuda.amp.autocast(enabled=False):
-        if bias.numel() != 0:
+        if bias is not None and bias.numel() != 0:
             return bgrad_dgelu_fused_(grad_output, inp, bias)
         return None, dgelu_fused_(grad_output, inp)
 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 644af2c22c..85fae4798c 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -11,7 +11,7 @@ import fcntl
 import struct
 from abc import ABC, abstractmethod
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
 
 import torch
@@ -406,6 +406,36 @@ def __init__(self) -> None:
         self.fsdp_wrapped = False
         self.fsdp_group = None
         self._fp8_workspaces: Dict[str, Float8Tensor] = {}
+        self.activation_dtype: Optional[torch.dtype] = None
+
+        # Fast getter for parameters
+        # Note: torch.nn.Module does not store parameters like normal
+        # attrs, but rather in a dict. When attempting to access, the
+        # module will raise an AttributeError in __getattribute__ and
+        # call a custom __getattr__. This is unnecessary overhead if
+        # we know we are accessing a parameter.
+        self._fast_get_param: Callable[str, torch.nn.Parameter]
+        self._fast_get_param = self.__dict__["_parameters"].get
+
+    # Names of attributes that can be set quickly (see __setattr__
+    # method)
+    _fast_setattr_names: Set[str] = {
+        "activation_dtype",
+        "fp8",
+        "fp8_initialized",
+        "fp8_calibration",
+        "fp8_parameters",
+    }
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name in TransformerEngineBaseModule._fast_setattr_names:
+            # torch.nn.Module has a custom __setattr__ that handles
+            # modules, parameters, and buffers. This is unnecessary
+            # overhead when setting plain attrs.
+            self.__dict__[name] = value
+        else:
+            # Default case
+            super().__setattr__(name, value)
 
     def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """Increase or decrease size of amax history based on given `length`.
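As an aside (not from the repository), here is a minimal, self-contained illustration of the access pattern the new `_fast_get_param` and `__setattr__` above rely on: parameters live in the module's `_parameters` dict, so binding that dict's `.get` once avoids going through `nn.Module.__getattr__`, and plain Python attributes can be written straight into `__dict__` to skip `nn.Module.__setattr__`'s parameter/buffer bookkeeping. The class and attribute names below are made up for the example.

```python
import torch


class FastAttrModule(torch.nn.Module):
    # Names that are safe to set without nn.Module's bookkeeping.
    _plain_attrs = {"activation_dtype"}

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(8))
        self.activation_dtype = None
        # Bind the parameter dict's .get once; later lookups bypass
        # nn.Module.__getattr__, which is only reached after an
        # AttributeError inside __getattribute__.
        self._fast_get_param = self.__dict__["_parameters"].get

    def __setattr__(self, name, value):
        if name in FastAttrModule._plain_attrs:
            self.__dict__[name] = value  # skip nn.Module.__setattr__ machinery
        else:
            super().__setattr__(name, value)


m = FastAttrModule()
assert m._fast_get_param("weight") is m.weight
m.activation_dtype = torch.float32  # stored as a plain attribute
```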
@@ -593,7 +623,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
             return
 
         # All checks after this have already been performed once, thus skip
-        if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
+        if self.activation_dtype == inp.dtype:
             return
 
         dtype = inp.dtype
@@ -708,10 +738,9 @@ def prepare_forward(
             FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
 
         with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
-            if not allow_non_contiguous:
-                yield inp.contiguous()
-            else:
-                yield inp
+            if not allow_non_contiguous and not inp.is_contiguous():
+                inp = inp.contiguous()
+            yield inp
 
         if self.fp8 and in_fp8_activation_recompute_phase():
             FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 14edd64249..ebb577182a 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -39,8 +39,9 @@
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..float8_tensor import Float8Tensor
+from ..tensor import Float8Tensor, QuantizedTensor
 from ..export import is_in_onnx_export_mode
+from ..cpu_offload import is_cpu_offload_enabled
 
 __all__ = ["GroupedLinear"]
 
@@ -715,11 +716,11 @@ def forward(
 
         with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:
 
-            weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+            weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)]
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
             if not self.fp8:
                 weight_tensors = [
-                    w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
+                    w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
                 ]
 
             weight_tensors_fp8 = [None] * self.num_gemms
@@ -746,8 +747,6 @@ def forward(
                         skip_update_flag=skip_fp8_weight_update,
                     )
 
-            from ..cpu_offload import CPUOffloadEnabled
-
             if torch.is_grad_enabled():
                 linear_fn = _GroupedLinear.apply
                 args = []
@@ -763,7 +762,7 @@ def forward(
                 self.fp8_calibration,
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
-                CPUOffloadEnabled,
+                is_cpu_offload_enabled(),
                 self.sequence_parallel,
                 self.activation_dtype,
                 self._offsets,
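A side note on the `is_cpu_offload_enabled()` helper used above (sketch below, not taken from the repository): a name bound with `from ..cpu_offload import CPUOffloadEnabled` at import time is a frozen copy of the boolean, which is why the old code re-imported it inside every `forward` call to see the current value. Reading the flag through a small accessor keeps the value fresh while removing the per-call import. The `set_cpu_offload` setter below is hypothetical; in Transformer Engine the flag is toggled by the offload context machinery.

```python
# Simulated module-level flag plus accessor, mirroring cpu_offload.py.
_CPU_OFFLOAD_ENABLED = False


def is_cpu_offload_enabled() -> bool:
    """Return the current value of the module-level flag."""
    return _CPU_OFFLOAD_ENABLED


def set_cpu_offload(enabled: bool) -> None:
    global _CPU_OFFLOAD_ENABLED
    _CPU_OFFLOAD_ENABLED = enabled


# A name bound once (like `from cpu_offload import CPUOffloadEnabled`)
# keeps the stale value; the accessor reflects the latest assignment.
snapshot = _CPU_OFFLOAD_ENABLED
set_cpu_offload(True)
print(snapshot)                   # False (stale copy)
print(is_cpu_offload_enabled())   # True  (current value)
```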
diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py
index 292fcd06de..dcf0cf62a0 100644
--- a/transformer_engine/pytorch/module/layernorm.py
+++ b/transformer_engine/pytorch/module/layernorm.py
@@ -12,7 +12,6 @@
 from torch.nn import init
 
 import transformer_engine_torch as tex
-from .base import TransformerEngineBaseModule
 from ..cpp_extensions import (
     layernorm_fwd_inf,
 )
@@ -143,6 +142,7 @@ def __init__(
             )
         )
         self.sequence_parallel = sequence_parallel
+        self.activation_dtype: Optional[torch.dtype] = None
 
         self.reset_parameters(defer_init=(device == "meta"))
 
@@ -186,8 +186,21 @@ def reset_parameters(self, defer_init=False) -> None:
     @no_torch_dynamo()
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
+
         # Set the activation type for AMP.
-        TransformerEngineBaseModule.set_activation_dtype(self, inp)
+        # Note: This will soon be deprecated with
+        # https://github.com/NVIDIA/TransformerEngine/pull/1033
+        if torch.is_autocast_enabled():
+            self.activation_dtype = torch.get_autocast_gpu_dtype()
+        elif self.activation_dtype != inp.dtype:
+            dtype = inp.dtype
+            for name, param in self.named_parameters():
+                if param is not None:
+                    assert dtype == param.dtype, (
+                        "Data types for parameters must match when outside of autocasted region. "
+                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
+                    )
+            self.activation_dtype = dtype
 
         if torch.is_grad_enabled():
             fwd_fn = _LayerNorm.apply
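For reference, the dtype-selection logic added to `LayerNorm.forward` above (and repeated for `RMSNorm` later in this patch) can be exercised standalone. The sketch below uses an assumed helper name, not a Transformer Engine API: under autocast the compute dtype comes from `torch.get_autocast_gpu_dtype()`, otherwise the input dtype is used after checking it against the parameters.

```python
import torch


def select_activation_dtype(module: torch.nn.Module, inp: torch.Tensor) -> torch.dtype:
    """Autocast dtype wins; otherwise validate the input dtype against params."""
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    for name, param in module.named_parameters():
        assert inp.dtype == param.dtype, (
            f"Input dtype {inp.dtype} does not match {name!r} dtype {param.dtype} "
            "outside of an autocast region."
        )
    return inp.dtype


ln = torch.nn.LayerNorm(8)
x = torch.randn(2, 8)
print(select_activation_dtype(ln, x))      # torch.float32
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # torch.bfloat16 when CUDA autocast is active (CPU-only builds
    # disable the autocast context and fall back to float32).
    print(select_activation_dtype(ln, x))
```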
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 92030a7f7a..8fc1ca24fb 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -46,6 +46,7 @@
 from ..float8_tensor import Float8Tensor
 from ..export import is_in_onnx_export_mode
 from ..tensor import QuantizedTensor
+from ..cpu_offload import is_cpu_offload_enabled
 
 __all__ = ["LayerNormLinear"]
 
@@ -94,8 +95,9 @@ def forward(
         fsdp_group: Union[dist_group_type, None],
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
-        in_features = ln_weight.numel()
-        assert inp.shape[-1] == in_features, "GEMM not possible"
+        out_features, in_features = weight.shape
+        inp_shape = inp.shape
+        assert inp_shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat)
@@ -339,7 +341,7 @@ def forward(
         ctx.use_bias = use_bias
         ctx.sequence_parallel = sequence_parallel
         ctx.tensor_parallel = tensor_parallel
-        ctx.inp_shape = inp.shape
+        ctx.inp_shape = inp_shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
         ctx.tp_size = tp_size
@@ -369,7 +371,7 @@ def forward(
             out, _ = allreduce(out, tp_group)
 
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
-        out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
+        out = out.view(-1, *inp_shape[1:-1], out_features)
 
         if return_layernorm_output:
             if return_layernorm_output_gathered:
@@ -1149,7 +1151,7 @@ def forward(
         with self.prepare_forward(inp, is_first_microbatch) as inp:
 
             # Get concatenated weight and bias tensors
-            unfused_weights = [getattr(self, name) for name in self.weight_names]
+            unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
             if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                 if self.fp8:
                     if len(unfused_weights) != 1:
@@ -1160,11 +1162,9 @@ def forward(
                     unfused_weights = [w.dequantize() for w in unfused_weights]
             weight_tensor = _noop_cat(unfused_weights)
             if self.use_bias:
-                bias_tensor = _noop_cat(
-                    [getattr(self, name) for name in self.bias_names],
-                )
+                bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
             else:
-                bias_tensor = getattr(self, self.bias_names[0])  # Unused
+                bias_tensor = self._fast_get_param(self.bias_names[0])  # Unused
 
             # Initialize FP8 weights if needed
             weight_fp8 = None
@@ -1190,8 +1190,6 @@ def forward(
                     skip_update_flag=skip_fp8_weight_update,
                 )
 
-            from ..cpu_offload import CPUOffloadEnabled
-
             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
                 args = []
@@ -1200,8 +1198,8 @@ def forward(
                 args = [None]
             args += (
                 inp,
-                self.layer_norm_weight,
-                self.layer_norm_bias,
+                self._fast_get_param("layer_norm_weight"),
+                self._fast_get_param("layer_norm_bias"),
                 weight_tensor,
                 weight_fp8,
                 bias_tensor,
@@ -1212,7 +1210,7 @@ def forward(
                 self.fp8_calibration,
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
-                CPUOffloadEnabled,
+                is_cpu_offload_enabled(),
                 self.tp_group,
                 self.tp_size,
                 self.sequence_parallel,
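Not part of the patch: a small, runnable illustration of the shape bookkeeping that the `inp_shape` and `out_features, in_features = weight.shape` changes above (and the matching changes in `linear.py` below) make explicit. The input is flattened to 2D for the GEMM and the output is viewed back with only the last dimension changed.

```python
import torch

inp = torch.randn(4, 2, 8)       # e.g. [seq, batch, in_features]
weight = torch.randn(16, 8)      # [out_features, in_features]

out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"

inputmat = inp.view(-1, in_features)   # [seq*batch, in_features]
out = inputmat @ weight.t()            # [seq*batch, out_features]

# [*, in_features] -> [*, out_features]; the first dim stays flexible
# for sequence-parallel cases where it may change.
out = out.view(-1, *inp_shape[1:-1], out_features)
assert out.shape == (4, 2, 16)
```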
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 6d5609ccd2..1ae0d66d78 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -54,6 +54,7 @@
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
 from ._common import _apply_normalization
+from ..cpu_offload import is_cpu_offload_enabled
 
 __all__ = ["LayerNormMLP"]
 
@@ -124,7 +125,8 @@ def forward(
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
-        assert inp.shape[-1] == in_features, "GEMM not possible"
+        inp_shape = inp.shape
+        assert inp_shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat)
@@ -433,7 +435,8 @@ def forward(
                 ln_weight.weight_offloading = True
                 fc1_weight.weight_offloading = True
                 fc2_weight.weight_offloading = True
-                fc1_bias.weight_offloading = True
+                if fc1_bias is not None:
+                    fc1_bias.weight_offloading = True
 
                 inputmat.activation_offloading = True
                 if normalization == "LayerNorm":
@@ -487,7 +490,7 @@ def forward(
         ctx.use_fc2_bias = use_fc2_bias
         ctx.sequence_parallel = sequence_parallel
         ctx.tensor_parallel = tensor_parallel
-        ctx.inp_shape = inp.shape
+        ctx.inp_shape = inp_shape
         ctx.tp_group = tp_group
         ctx.tp_size = tp_size
         ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
@@ -519,11 +522,11 @@ def forward(
             fc2_out, _ = allreduce(fc2_out, tp_group)
 
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
-        fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])
+        fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
 
         if return_layernorm_output:
             if return_layernorm_output_gathered:
-                shape = list(inp.shape)
+                shape = list(inp_shape)
                 shape[0] *= tp_size
                 return fc2_out, ln_out_return.view(shape)
             return fc2_out, ln_out_return.view_as(inp)
@@ -1470,8 +1473,10 @@ def forward(
         with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
 
             # Get weight tensors
-            fc1_weight = self.fc1_weight
-            fc2_weight = self.fc2_weight
+            fc1_weight = self._fast_get_param("fc1_weight")
+            fc1_bias = self._fast_get_param("fc1_bias")
+            fc2_weight = self._fast_get_param("fc2_weight")
+            fc2_bias = self._fast_get_param("fc2_bias")
             if not self.fp8:
                 if isinstance(fc1_weight, Float8Tensor):
                     fc1_weight = fc1_weight.from_float8()
@@ -1524,8 +1529,6 @@ def forward(
             if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
                 self.bias_gelu_nvfusion = False
 
-            from ..cpu_offload import CPUOffloadEnabled
-
             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormMLP.apply
                 args = []
@@ -1534,15 +1537,15 @@ def forward(
                 args = [None]
             args += (
                 inp,
-                self.layer_norm_weight,
-                self.layer_norm_bias,
+                self._fast_get_param("layer_norm_weight"),
+                self._fast_get_param("layer_norm_bias"),
                 fc1_weight,
                 fc1_weight_fp8,
-                self.fc1_bias,
+                fc1_bias,
                 self.use_bias,
                 fc2_weight,
                 fc2_weight_fp8,
-                self.fc2_bias,
+                fc2_bias,
                 self.apply_bias and not self.gemm_bias_unfused_add,
                 self.eps,
                 is_first_microbatch,
@@ -1550,7 +1553,7 @@ def forward(
                 self.fp8_calibration,
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
-                CPUOffloadEnabled,
+                is_cpu_offload_enabled(),
                 self.tp_group,
                 self.tp_size,
                 self.sequence_parallel,
@@ -1580,12 +1583,12 @@ def forward(
             out, ln_out = out
 
         if self.gemm_bias_unfused_add:
-            out = out + cast_if_needed(self.fc2_bias, self.activation_dtype)
+            out = out + cast_if_needed(fc2_bias, self.activation_dtype)
 
         if self.return_bias:
             if self.return_layernorm_output:
-                return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out
-            return out, cast_if_needed(self.fc2_bias, self.activation_dtype)
+                return out, cast_if_needed(fc2_bias, self.activation_dtype), ln_out
+            return out, cast_if_needed(fc2_bias, self.activation_dtype)
         if self.return_layernorm_output:
             return out, ln_out
         return out
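As an aside, the `fc1_bias is not None` guard above follows the same convention used elsewhere in this patch (`bias is not None and bias.numel() != 0`): a disabled bias may be either `None` or an empty placeholder tensor. A tiny illustration, with a made-up helper name:

```python
import torch


def bias_is_active(bias) -> bool:
    # A disabled bias may be represented as None or as an empty tensor.
    return bias is not None and bias.numel() != 0


print(bias_is_active(None))             # False
print(bias_is_active(torch.empty(0)))   # False
print(bias_is_active(torch.zeros(16)))  # True
```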
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 8e19a65a28..d6406f6119 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -48,6 +48,7 @@
 from ..float8_tensor import Float8Tensor
 from ..export import is_in_onnx_export_mode
 from ..tensor import QuantizedTensor
+from ..cpu_offload import is_cpu_offload_enabled
 
 __all__ = ["Linear"]
 
@@ -87,8 +88,9 @@ def forward(
         is_input_fp8 = isinstance(inp, Float8Tensor)
 
         # Make sure input dimensions are compatible
-        in_features = weight.shape[-1]
-        assert inp.shape[-1] == in_features, "GEMM not possible"
+        out_features, in_features = weight.shape
+        inp_shape = inp.shape
+        assert inp_shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view(-1, in_features)
         if fp8:
             assert_dim_for_fp8_exec(inputmat)
@@ -180,7 +182,7 @@ def forward(
                 out = ub_obj_projout.get_ubuf_output(1)
                 dim_size = list(inputmat_total.size())
                 dim_size[0] = dim_size[0] // tp_world_size
-                dim_size[1] = weight_fp8.size(0)
+                dim_size[1] = out_features
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                 if ub_obj_projout.is_p2p_overlap():
                     if ub_obj_projout.is_atomic_gemm():
@@ -200,7 +202,7 @@ def forward(
                         ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
             else:
                 dim_size = list(inputmat_total.size())
-                dim_size[1] = weight_fp8.size(0)
+                dim_size[1] = out_features
                 out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
 
             _ = fp8_gemm(
@@ -260,7 +262,7 @@ def forward(
                 out = ub_obj_projout.get_ubuf_output(1)
                 dim_size = list(inputmat_total.size())
                 dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
-                dim_size[1] = weight.size(0)
+                dim_size[1] = out_features
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                 if ub_obj_projout.is_p2p_overlap():
                     ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
@@ -268,7 +270,7 @@ def forward(
                     ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
             else:
                 dim_size = list(inputmat_total.size())
-                dim_size[1] = weight.size(0)
+                dim_size[1] = out_features
                 out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
 
             _ = gemm(
@@ -334,7 +336,7 @@ def forward(
         ctx.use_bias = use_bias
         ctx.sequence_parallel = sequence_parallel
         ctx.tensor_parallel = tensor_parallel
-        ctx.inp_shape = inp.shape
+        ctx.inp_shape = inp_shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
         ctx.ub_overlap_ag = ub_overlap_ag
@@ -358,7 +360,7 @@ def forward(
             out, _ = allreduce(out, tp_group)
 
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
-        return out.view(-1, *inp.shape[1:-1], out.shape[-1])
+        return out.view(-1, *inp_shape[1:-1], out_features)
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
@@ -941,7 +943,7 @@ def forward(
         ) as inp:
 
             # Get concatenated weight and bias tensors
-            unfused_weights = [getattr(self, name) for name in self.weight_names]
+            unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
             if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                 if self.fp8:
                     if len(unfused_weights) != 1:
@@ -952,11 +954,9 @@ def forward(
                     unfused_weights = [w.dequantize() for w in unfused_weights]
             weight_tensor = _noop_cat(unfused_weights)
             if self.use_bias:
-                bias_tensor = _noop_cat(
-                    [getattr(self, name) for name in self.bias_names],
-                )
+                bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
             else:
-                bias_tensor = getattr(self, self.bias_names[0])  # Unused
+                bias_tensor = self._fast_get_param(self.bias_names[0])  # Unused
 
             # Initialize FP8 weights if needed
             weight_fp8 = None
@@ -983,8 +983,6 @@ def forward(
                     fsdp_group=self.fsdp_group,
                 )
 
-            from ..cpu_offload import CPUOffloadEnabled
-
             if torch.is_grad_enabled():
                 linear_fn = _Linear.apply
                 args = []
@@ -1002,7 +1000,7 @@ def forward(
                 self.fp8_calibration,
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
-                CPUOffloadEnabled,
+                is_cpu_offload_enabled(),
                 self.tp_group,
                 self.tp_size,
                 self.sequence_parallel,
diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py
index d5dc400206..7bb16635f5 100644
--- a/transformer_engine/pytorch/module/rmsnorm.py
+++ b/transformer_engine/pytorch/module/rmsnorm.py
@@ -11,7 +11,6 @@
 from torch.nn.parameter import Parameter
 from torch.nn import init
 
-from .base import TransformerEngineBaseModule
 from .. import cpp_extensions as tex
 from ..jit import no_torch_dynamo
 from ..utils import cast_if_needed
@@ -146,6 +145,7 @@ def __init__(
             )
         )
         self.sequence_parallel = sequence_parallel
+        self.activation_dtype: Optional[torch.dtype] = None
 
         self.reset_parameters(defer_init=(device == "meta"))
 
@@ -185,7 +185,19 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """RMSNorm FWD"""
 
         # Set the activation type for AMP.
-        TransformerEngineBaseModule.set_activation_dtype(self, inp)
+        # Note: This will soon be deprecated with
+        # https://github.com/NVIDIA/TransformerEngine/pull/1033
+        if torch.is_autocast_enabled():
+            self.activation_dtype = torch.get_autocast_gpu_dtype()
+        elif self.activation_dtype != inp.dtype:
+            dtype = inp.dtype
+            for name, param in self.named_parameters():
+                if param is not None:
+                    assert dtype == param.dtype, (
+                        "Data types for parameters must match when outside of autocasted region. "
+                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
+                    )
+            self.activation_dtype = dtype
 
         if torch.is_grad_enabled():
             fwd_fn = _RMSNorm.apply
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 020d262be2..48761faa41 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -751,7 +751,7 @@ def forward(
         return output
 
     def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
-        if drop_path is None and bias.numel() != 0:
+        if drop_path is None and bias is not None and bias.numel() != 0:
             if self.bias_dropout_fusion:
                 if self.training:
                     bias_dropout_add_func = bias_dropout_add_fused_train
@@ -763,7 +763,7 @@ def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
             with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)
         else:
-            if bias.numel() != 0:
+            if bias is not None and bias.numel() != 0:
                 hidden_state = hidden_state + bias
             out = torch.nn.functional.dropout(
                 hidden_state, p=self.hidden_dropout, training=self.training
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index d5145455b8..947c642c2c 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -218,8 +218,12 @@ def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch
 
 def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     """Cast tensor to dtype"""
+    if tensor is None:
+        return None
+    if tensor.dtype == dtype:
+        return tensor
     with torch.enable_grad():
-        return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
+        return tensor.to(dtype=dtype)
 
 
 def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: