diff --git a/test/test_utils.py b/test/test_utils.py index af5dc09985c..6537c19ff54 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -15,10 +15,13 @@ import torch from _utils_internal import capture_log_records, get_default_devices +from packaging import version from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) + @pytest.mark.parametrize("value", ["True", "1", "true"]) def test_get_binary_env_var_positive(value): @@ -382,6 +385,9 @@ def test_rng_decorator(device): # Check that 'capture_log_records' captures records emitted when torch # recompiles a function. +@pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" +) def test_capture_log_records_recompile(): torch.compiler.reset()