Skip to content

Commit

Permalink
import jaxns only if double precision is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
renecotyfanboy committed Oct 16, 2024
1 parent 95937ff commit 19ae8f0
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions test/contrib/test_nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import numpyro

try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam
if os.environ.get("JAX_ENABLE_X64"):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam

except ImportError:
pytestmark = pytest.mark.skip(reason="jaxns is not installed")

import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform, ExpTransform

Expand Down

0 comments on commit 19ae8f0

Please sign in to comment.