diff --git a/test/contrib/test_nested_sampling.py b/test/contrib/test_nested_sampling.py index d56e16504..a04369de0 100644 --- a/test/contrib/test_nested_sampling.py +++ b/test/contrib/test_nested_sampling.py @@ -14,12 +14,14 @@ import numpyro try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam + if os.environ.get("JAX_ENABLE_X64"): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam except ImportError: pytestmark = pytest.mark.skip(reason="jaxns is not installed") + import numpyro.distributions as dist from numpyro.distributions.transforms import AffineTransform, ExpTransform