Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Add utility for filtering out skpped tests in large paremtrization gr…
Browse files Browse the repository at this point in the history
…oups

ghstack-source-id: 275f276e73a1f6035b96ebe8901a7ed87f7ccf3a
Pull Request resolved: #303
  • Loading branch information
drisspg committed Jul 3, 2024
1 parent c57aa9e commit b1c0258
Showing 1 changed file with 75 additions and 69 deletions.
144 changes: 75 additions & 69 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import re
import unittest
import warnings
from itertools import product

import pytest

Expand Down Expand Up @@ -52,6 +53,37 @@
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def filtered_parametrize(param_list, filter_func=None):
"""
A decorator that works like pytest.mark.parametrize but filters out
unwanted parameter combinations.
:param param_list: A list of tuples, each containing (arg_name, [arg_values])
:param filter_func: A function that takes a dictionary of parameter names and values,
and returns True for valid combinations, False otherwise
"""

def decorator(func):
arg_names = [param[0] for param in param_list]
arg_values = [param[1] for param in param_list]

all_combinations = product(*arg_values)
if filter_func:
valid_combinations = [
combo
for combo in all_combinations
if filter_func(dict(zip(arg_names, combo)))
]
else:
valid_combinations = list(all_combinations)

return pytest.mark.parametrize(
argnames=arg_names, argvalues=valid_combinations
)(func)

return decorator


def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._data == b._data).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
Expand Down Expand Up @@ -230,17 +262,35 @@ def _test_linear_impl(
# verify initialization flags got updated
assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
@staticmethod
def is_valid_combination(params):
if not params["emulate"]:
if not torch.cuda.is_available():
return False
if torch.cuda.get_device_capability() < (9, 0):
return False

if params["linear_type"] == LinearType.DYNAMIC:
return all(
params[key] == TensorScalingType.DYNAMIC
for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"]
)

return True

@filtered_parametrize(
[
("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
("emulate", [True, False] if is_H100 else [True]),
("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
(
"scaling_type_dL_dY",
[TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
),
],
filter_func=is_valid_combination,
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear_nobias(
Expand All @@ -252,28 +302,6 @@ def test_linear_nobias(
scaling_type_w: TensorScalingType,
scaling_type_dL_dY: TensorScalingType,
):
if not emulate:
if not torch.cuda.is_available():
warnings.warn("CUDA not available")
pytest.skip()
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()
if linear_type is LinearType.DYNAMIC:
# Only test one combination of scaling types, as they are a no-op
# for Float8DynamicLinear. It would be cleaner to split into two
# tests, but IMO not worth it since Float8DynamicLinear will be
# deleted soon
is_all_dynamic = (
scaling_type_x is TensorScalingType.DYNAMIC
and scaling_type_w is TensorScalingType.DYNAMIC
and scaling_type_dL_dY is TensorScalingType.DYNAMIC
)
if not is_all_dynamic:
pytest.skip()

x = torch.randn(*x_shape, device="cuda")
m_ref = nn.Linear(16, 32, bias=False, device="cuda")
self._test_linear_impl(
Expand All @@ -286,20 +314,20 @@ def test_linear_nobias(
scaling_type_dL_dY,
)

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
@filtered_parametrize(
[
("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
("emulate", [True, False] if is_H100 else [True]),
("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
(
"scaling_type_dL_dY",
[TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
),
("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
],
filter_func=is_valid_combination,
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear_bias(
Expand All @@ -312,28 +340,6 @@ def test_linear_bias(
emulate: bool,
linear_dtype: torch.dtype,
):
if not emulate:
if not torch.cuda.is_available():
warnings.warn("CUDA not available")
pytest.skip()
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()
if linear_type is LinearType.DYNAMIC:
# Only test one combination of scaling types, as they are a no-op
# for Float8DynamicLinear. It would be cleaner to split into two
# tests, but IMO not worth it since Float8DynamicLinear will be
# deleted soon
is_all_dynamic = (
scaling_type_x is TensorScalingType.DYNAMIC
and scaling_type_w is TensorScalingType.DYNAMIC
and scaling_type_dL_dY is TensorScalingType.DYNAMIC
)
if not is_all_dynamic:
pytest.skip()

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
self._test_linear_impl(
Expand Down

0 comments on commit b1c0258

Please sign in to comment.