From b1c025859a4daccb5459307b9c5c5a3b75bde077 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 3 Jul 2024 11:53:43 -0700
Subject: [PATCH] Add utility for filtering out skipped tests in large
 parametrization groups

ghstack-source-id: 275f276e73a1f6035b96ebe8901a7ed87f7ccf3a
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/303
---
 test/test_base.py | 144 ++++++++++++++++++++++++----------------------
 1 file changed, 75 insertions(+), 69 deletions(-)

diff --git a/test/test_base.py b/test/test_base.py
index 1fee3bc..86c1462 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -9,6 +9,7 @@
 import re
 import unittest
 import warnings
+from itertools import product
 
 import pytest
 
@@ -52,6 +53,37 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
+def filtered_parametrize(param_list, filter_func=None):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    :param param_list: A list of tuples, each containing (arg_name, [arg_values])
+    :param filter_func: A function that takes a dictionary of parameter names and values,
+                        and returns True for valid combinations, False otherwise
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator
+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -230,17 +262,35 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    @staticmethod
+    def is_valid_combination(params):
+        if not params["emulate"]:
+            if not torch.cuda.is_available():
+                return False
+            if torch.cuda.get_device_capability() < (9, 0):
+                return False
+
+        if params["linear_type"] == LinearType.DYNAMIC:
+            return all(
+                params[key] == TensorScalingType.DYNAMIC
+                for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"]
+            )
+
+        return True
+
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_nobias(
@@ -252,28 +302,6 @@ def test_linear_nobias(
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
@@ -286,20 +314,20 @@ def test_linear_nobias(
             scaling_type_dL_dY,
         )
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_bias(
@@ -312,28 +340,6 @@ def test_linear_bias(
         emulate: bool,
         linear_dtype: torch.dtype,
    ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(
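
For reference, here is a minimal standalone sketch (not part of the patch) of what the new collection-time filtering does for the test_linear_nobias grid. It assumes an H100-like host where the CUDA/capability check passes, and it uses plain strings in place of LinearType / TensorScalingType so it runs without the float8_experimental package:

from itertools import product

# Parameter grid from test_linear_nobias, with strings standing in for the enums.
param_list = [
    ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
    ("linear_type", ["delayed", "dynamic"]),
    ("emulate", [True, False]),  # would be [True] only on a non-H100 host
    ("scaling_type_x", ["delayed", "dynamic"]),
    ("scaling_type_w", ["delayed", "dynamic"]),
    ("scaling_type_dL_dY", ["delayed", "dynamic"]),
]


def is_valid_combination(params):
    # Mirrors the patch's filter, with the CUDA/capability check assumed to pass:
    # a dynamic linear is only exercised with all-dynamic scaling.
    if params["linear_type"] == "dynamic":
        return all(
            params[key] == "dynamic"
            for key in ("scaling_type_x", "scaling_type_w", "scaling_type_dL_dY")
        )
    return True


arg_names = [name for name, _ in param_list]
all_combos = [
    dict(zip(arg_names, values))
    for values in product(*(choices for _, choices in param_list))
]
kept = [combo for combo in all_combos if is_valid_combination(combo)]
print(f"{len(kept)} of {len(all_combos)} combinations are generated")  # 54 of 96

The 42 filtered-out combinations are exactly the ones the old code collected and then discarded via pytest.skip() at runtime; filtered_parametrize drops them before pytest ever generates the test IDs.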