This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add utility for filtering out skipped tests in large cross-product groups #303

Open
wants to merge 4 commits into base: gh/drisspg/3/base
2 changes: 1 addition & 1 deletion float8_experimental/float8_linear.py
@@ -300,7 +300,7 @@ def cast_x_to_float8(
        if torch.is_autocast_enabled():
            # For now, hardcode to GPU's autocast dtype
            # if we need CPU support in the future, we can add it
            autocast_dtype = torch.get_autocast_gpu_dtype()
            autocast_dtype = torch.get_autocast_dtype("cuda")
            x = x.to(autocast_dtype)

        if self.scaling_type_x is TensorScalingType.DELAYED:
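The change above replaces the deprecated torch.get_autocast_gpu_dtype() with the device-generic torch.get_autocast_dtype("cuda"). A minimal sketch of the new call, assuming a recent PyTorch (2.4 or later) and a CUDA device; this snippet is illustrative and not part of the PR:

import torch

# Inside an autocast region, the device-generic API reports the active autocast dtype;
# this mirrors the cast performed in Float8Linear.cast_x_to_float8.
x = torch.randn(4, 4, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    if torch.is_autocast_enabled():
        autocast_dtype = torch.get_autocast_dtype("cuda")  # replaces torch.get_autocast_gpu_dtype()
        x = x.to(autocast_dtype)
assert x.dtype == torch.bfloat16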
2 changes: 1 addition & 1 deletion float8_experimental/float8_linear_utils.py
@@ -227,7 +227,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
        fp8_layers = get_float8_layers(model)

    if len(fp8_layers) == 0:
        log.warn(
        log.warning(
            "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
        )
        return
80 changes: 54 additions & 26 deletions test/test_base.py
@@ -10,6 +10,8 @@
import re
import unittest
import warnings
from itertools import product
from typing import Any, Callable, Dict, List, Optional, Tuple

import pytest

@@ -50,6 +52,42 @@
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def filtered_parametrize(
    param_list: List[Tuple[str, List[Any]]],
    filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
):
    """
    A decorator that works like pytest.mark.parametrize but filters out
    unwanted parameter combinations.

    Args:
        param_list: A list of tuples, each containing (arg_name, [arg_values])
        filter_func: A function that takes a dictionary of parameter names and values,
            and returns True for valid combinations, False otherwise

    """

    def decorator(func):
        arg_names = [param[0] for param in param_list]
        arg_values = [param[1] for param in param_list]

        all_combinations = product(*arg_values)
        if filter_func:
            valid_combinations = [
                combo
                for combo in all_combinations
                if filter_func(dict(zip(arg_names, combo)))
            ]
        else:
            valid_combinations = list(all_combinations)

        return pytest.mark.parametrize(
            argnames=arg_names, argvalues=valid_combinations
        )(func)

    return decorator


def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
    assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -243,48 +281,38 @@ def test_linear(
        scaling_type_x: TensorScalingType,
        scaling_type_w: TensorScalingType,
        scaling_type_dL_dY: TensorScalingType,
        linear_dtype: torch.dtype,
        linear_bias: bool,
    ):
        if not emulate:
            if not torch.cuda.is_available():
                warnings.warn("CUDA not available")
                pytest.skip()
            elif torch.cuda.get_device_capability() < (9, 0):
                warnings.warn(
                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
                )
                pytest.skip()
        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
        x = torch.randn(*x_shape, device="cuda")
        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
        self._test_linear_impl(
            x,
            m_ref,
            linear_type,
            emulate,
            scaling_type_x,
            scaling_type_w,
            scaling_type_dL_dY,
        )

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]

    @filtered_parametrize(
        [
            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
            ("emulate", [True, False] if is_H100 else [True]),
            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
            (
                "scaling_type_dL_dY",
                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
            ),
            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
        ],
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_autocast_outputs(
        self,
        emulate: bool,
        linear_dtype: torch.dtype,
    ):
        if not emulate:
            if not torch.cuda.is_available():
                warnings.warn("CUDA not available")
                pytest.skip()
            elif torch.cuda.get_device_capability() < (9, 0):
                warnings.warn(
                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
                )
                pytest.skip()

        m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
        kwargs = {
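For reference, a hedged usage sketch of the filtered_parametrize helper added in test/test_base.py above. The parameter names and values mirror the ones in the diff; the import path, the _keep filter, and test_example are illustrative assumptions rather than code from this PR:

import torch

from test_base import filtered_parametrize  # assumed import path; defined as in the diff above

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def _keep(params) -> bool:
    # Hypothetical filter: drop the emulate=False combinations on machines without an
    # H100-class GPU, the case the in-test pytest.skip() calls handled at runtime.
    return params["emulate"] or is_H100


@filtered_parametrize(
    [
        ("emulate", [True, False]),
        ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
    ],
    filter_func=_keep,
)
def test_example(emulate, linear_dtype):
    # Only the combinations that survive _keep are generated as test cases, so no
    # runtime skip is needed inside the test body.
    assert linear_dtype in (torch.float16, torch.bfloat16, torch.float32)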