diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b59f3cb9c..ee8cedebe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -112,6 +112,9 @@ jobs: XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain" XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain" + - name: Test custom prng + run: | + JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py examples: diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 5ff91f03a..0436b10ed 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -10,7 +10,6 @@ import operator from jax import grad, jacfwd, numpy as jnp, random, vmap -from jax.random import KeyArray from jax.tree_util import tree_map from numpyro import handlers @@ -370,10 +369,10 @@ def _update_force(attr_force, rep_force, jac): ) return jnp.linalg.norm(particle_grads), res_grads - def init(self, rng_key: KeyArray, *args, **kwargs): + def init(self, rng_key, *args, **kwargs): """Register random variable transformations, constraints and determine initialize positions of the particles. - :param KeyArray rng_key: Random number generator seed. + :param rng_key: Random number generator seed. :param args: Arguments to the model / guide. :param kwargs: Keyword arguments to the model / guide. :return: initial :data:`SteinVIState` diff --git a/numpyro/contrib/tfp/mcmc.py b/numpyro/contrib/tfp/mcmc.py index f1022574d..bce3b8da7 100644 --- a/numpyro/contrib/tfp/mcmc.py +++ b/numpyro/contrib/tfp/mcmc.py @@ -14,7 +14,7 @@ from numpyro.infer import init_to_uniform from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import initialize_model -from numpyro.util import identity +from numpyro.util import identity, is_prng_key TFPKernelState = namedtuple("TFPKernelState", ["z", "kernel_results", "rng_key"]) @@ -174,7 +174,7 @@ def init( self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={} ): # non-vectorized - if rng_key.ndim == 1: + if is_prng_key(rng_key): rng_key, rng_key_init_model = random.split(rng_key) # vectorized else: @@ -190,7 +190,7 @@ def init( " `target_log_prob_fn`." ) - if rng_key.ndim == 1: + if is_prng_key(rng_key): init_state = self._init_fn(init_params, rng_key) else: # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index 5b4233fb0..f0c7b93c7 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -14,7 +14,8 @@ ZeroInflatedDistribution, ) from numpyro.distributions.distribution import Distribution -from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample +from numpyro.distributions.util import promote_shapes, validate_sample +from numpyro.util import is_prng_key def _log_beta_1(alpha, value): diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index a0ea4349f..5df46dade 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -64,7 +64,6 @@ betaincinv, cholesky_of_inverse, gammaincinv, - is_prng_key, lazy_property, matrix_to_tril_vec, promote_shapes, @@ -72,6 +71,7 @@ validate_sample, vec_to_tril_matrix, ) +from numpyro.util import is_prng_key class AsymmetricLaplace(Distribution): diff --git a/numpyro/distributions/copula.py b/numpyro/distributions/copula.py index 98fec536a..383bee2c0 100644 --- a/numpyro/distributions/copula.py +++ b/numpyro/distributions/copula.py @@ -6,12 +6,8 @@ import numpyro.distributions.constraints as constraints from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal from numpyro.distributions.distribution import Distribution -from numpyro.distributions.util import ( - clamp_probs, - is_prng_key, - lazy_property, - validate_sample, -) +from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample +from numpyro.util import is_prng_key class GaussianCopula(Distribution): diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 67a91a8d1..841dd2194 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -16,14 +16,13 @@ from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import ( - is_prng_key, lazy_property, promote_shapes, safe_normalize, validate_sample, von_mises_centered, ) -from numpyro.util import while_loop +from numpyro.util import is_prng_key, while_loop def _numel(shape): diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 0952c53da..e5fbb2f88 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -41,13 +41,12 @@ binomial, categorical, clamp_probs, - is_prng_key, lazy_property, multinomial, promote_shapes, validate_sample, ) -from numpyro.util import not_jax_tracer +from numpyro.util import is_prng_key, not_jax_tracer def _to_probs_bernoulli(logits): diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index f5117d1f3..de29ad988 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -7,7 +7,8 @@ from numpyro.distributions import Distribution, constraints from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs -from numpyro.distributions.util import is_prng_key, validate_sample +from numpyro.distributions.util import validate_sample +from numpyro.util import is_prng_key def Mixture(mixing_distribution, component_distributions, *, validate_args=None): diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index eeb2de398..a43b6268b 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -19,11 +19,11 @@ from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import ( clamp_probs, - is_prng_key, lazy_property, promote_shapes, validate_sample, ) +from numpyro.util import is_prng_key class LeftTruncatedDistribution(Distribution): diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 220961ef6..f59e336db 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -4,6 +4,7 @@ from collections import namedtuple from functools import partial, update_wrapper import math +import warnings import numpy as np @@ -612,11 +613,11 @@ def safe_normalize(x, *, p=2): return x -# src: https://github.com/google/jax/blob/5a41779fbe12ba7213cd3aa1169d3b0ffb02a094/jax/_src/random.py#L95 def is_prng_key(key): - if isinstance(key, jax.random.PRNGKeyArray): - return key.shape == () + warnings.warn("Please use numpyro.util.is_prng_key.", DeprecationWarning) try: + if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): + return key.shape == () return key.shape == (2,) and key.dtype == np.uint32 except AttributeError: return False diff --git a/numpyro/handlers.py b/numpyro/handlers.py index c4a9f7235..23357187c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -93,7 +93,7 @@ apply_stack, plate, ) -from numpyro.util import find_stack_level, not_jax_tracer +from numpyro.util import find_stack_level, is_prng_key, not_jax_tracer __all__ = [ "block", @@ -705,15 +705,15 @@ class seed(Messenger): """ def __init__(self, fn=None, rng_seed=None, hide_types=None): - if isinstance(rng_seed, int) or ( - isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed) + if not is_prng_key(rng_seed) and ( + isinstance(rng_seed, int) + or ( + isinstance(rng_seed, (np.ndarray, jnp.ndarray)) + and not jnp.shape(rng_seed) + ) ): rng_seed = random.PRNGKey(rng_seed) - if not ( - isinstance(rng_seed, (np.ndarray, jnp.ndarray)) - and rng_seed.dtype == jnp.uint32 - and rng_seed.shape == (2,) - ): + if not is_prng_key(rng_seed): raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed))) self.rng_key = rng_seed self.hide_types = [] if hide_types is None else hide_types diff --git a/numpyro/infer/barker.py b/numpyro/infer/barker.py index d58a8ff6e..5496e7baa 100644 --- a/numpyro/infer/barker.py +++ b/numpyro/infer/barker.py @@ -14,7 +14,7 @@ from numpyro.infer.initialization import init_to_uniform from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import initialize_model -from numpyro.util import identity +from numpyro.util import identity, is_prng_key BarkerMHState = namedtuple( "BarkerMHState", @@ -170,7 +170,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): self._num_warmup = num_warmup # TODO (low-priority): support chain_method="vectorized", i.e. rng_key is a batch of keys - assert rng_key.shape == (2,), ( + assert is_prng_key(rng_key), ( "BarkerMH only supports chain_method='parallel' or chain_method='sequential'." " Please put in a feature request if you think it would be useful to be able " "to use BarkerMH in vectorized mode." diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index aa1aae392..8aac9e4d0 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -20,7 +20,7 @@ ) from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model -from numpyro.util import cond, fori_loop, identity +from numpyro.util import cond, fori_loop, identity, is_prng_key HMCState = namedtuple( "HMCState", @@ -703,7 +703,7 @@ def init( self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={} ): # non-vectorized - if rng_key.ndim == 1: + if is_prng_key(rng_key): rng_key, rng_key_init_model = random.split(rng_key) # vectorized else: @@ -749,7 +749,7 @@ def init( model_kwargs=model_kwargs, rng_key=rng_key, ) - if rng_key.ndim == 1: + if is_prng_key(rng_key): init_state = hmc_init_fn(init_params, rng_key) else: # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 082b51460..92f461a73 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -14,7 +14,13 @@ from jax.tree_util import tree_flatten, tree_map from numpyro.diagnostics import print_summary -from numpyro.util import cached_by, find_stack_level, fori_collect, identity +from numpyro.util import ( + cached_by, + find_stack_level, + fori_collect, + identity, + is_prng_key, +) __all__ = [ "MCMCKernel", @@ -418,7 +424,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): sample_fn, postprocess_fn = self._get_cached_fns() diagnostics = ( lambda x: self.sampler.get_diagnostics_str(x[0]) - if rng_key.ndim == 1 + if is_prng_key(rng_key) else "" ) # noqa: E731 init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) @@ -595,7 +601,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): self._args = args self._kwargs = kwargs init_state = self._get_cached_init_state(rng_key, args, kwargs) - if self.num_chains > 1 and rng_key.ndim == 1: + if self.num_chains > 1 and is_prng_key(rng_key): rng_key = random.split(rng_key, self.num_chains) if self._warmup_state is not None: diff --git a/numpyro/infer/sa.py b/numpyro/infer/sa.py index 085cc18d4..91f963f41 100644 --- a/numpyro/infer/sa.py +++ b/numpyro/infer/sa.py @@ -12,7 +12,7 @@ from numpyro.distributions.util import cholesky_update from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import init_to_uniform, initialize_model -from numpyro.util import identity +from numpyro.util import identity, is_prng_key def _get_proposal_loc_and_scale(samples, loc, scale, new_sample): @@ -331,7 +331,7 @@ def init( self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={} ): # non-vectorized - if rng_key.ndim == 1: + if is_prng_key(rng_key): rng_key, rng_key_init_model = random.split(rng_key) # vectorized else: @@ -358,7 +358,7 @@ def init( model_args=model_args, model_kwargs=model_kwargs, ) - if rng_key.ndim == 1: + if is_prng_key(rng_key): init_state = sa_init_fn(init_params, rng_key) else: init_state = vmap(sa_init_fn)(init_params, rng_key) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index e3d269a6e..916545cda 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -25,6 +25,7 @@ from numpyro.util import ( _validate_model, find_stack_level, + is_prng_key, not_jax_tracer, soft_vmap, while_loop, @@ -435,7 +436,7 @@ def _find_valid_params(rng_key, exit_early=False): return (init_params, pe, z_grad), is_valid # Handle possible vectorization - if rng_key.ndim == 1: + if is_prng_key(rng_key): (init_params, pe, z_grad), is_valid = _find_valid_params( rng_key, exit_early=True ) @@ -644,7 +645,7 @@ def initialize_model( """ model_kwargs = {} if model_kwargs is None else model_kwargs substituted_model = substitute( - seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]), + seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]), substitute_fn=init_strategy, ) ( @@ -816,9 +817,10 @@ def single_prediction(val): return {name: value for name, value in pred_samples.items() if name in sites} num_samples = int(np.prod(batch_shape)) + key_shape = rng_key.shape if num_samples > 1: rng_key = random.split(rng_key, num_samples) - rng_key = rng_key.reshape((*batch_shape, 2)) + rng_key = rng_key.reshape(batch_shape + key_shape) chunk_size = num_samples if parallel else 1 return soft_vmap( single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index 5386435f7..a53f26ae8 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -7,7 +7,7 @@ from jax.experimental.pjit import pjit_p from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic from jax.interpreters.pxla import xla_pmap_p -import jax.linear_util as lu +import jax.extend.linear_util as lu import jax.numpy as jnp diff --git a/numpyro/util.py b/numpyro/util.py index 09a338588..ccc8c09ae 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -141,6 +141,15 @@ def fori_loop(lower, upper, body_fun, init_val): return lax.fori_loop(lower, upper, body_fun, init_val) +def is_prng_key(key): + try: + if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): + return key.shape == () + return key.shape == (2,) and key.dtype == np.uint32 + except AttributeError: + return False + + def not_jax_tracer(x): """ Checks if `x` is not an array generated inside `jit`, `pmap`, `vmap`, or `lax_control_flow`. diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 4af95a237..fd297c420 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -273,7 +273,7 @@ def model(obs): expected_shape = (num_particles, *np.shape(inner_param["value"])) assert init_value.shape == expected_shape if "auto_loc" in name or name == "b": - assert np.alltrue(init_value != np.zeros(expected_shape)) + assert np.all(init_value != np.zeros(expected_shape)) assert np.unique(init_value).shape == init_value.reshape(-1).shape elif "scale" in name: assert_allclose(init_value[init_value != 0.0], 0.1, rtol=1e-6) @@ -311,7 +311,7 @@ def model(obs): expected_shape = (num_particles, latent_dim) assert expected_shape == init_value.shape - assert np.alltrue(init_value != np.zeros(expected_shape)) + assert np.all(init_value != np.zeros(expected_shape)) assert np.unique(init_value).shape == init_value.reshape(-1).shape diff --git a/test/contrib/einstein/test_steinvi_util.py b/test/contrib/einstein/test_steinvi_util.py index 532f7e18c..617effb71 100644 --- a/test/contrib/einstein/test_steinvi_util.py +++ b/test/contrib/einstein/test_steinvi_util.py @@ -38,7 +38,7 @@ @pytest.mark.parametrize("m", matrices) def test_posdef(m): pd_m = posdef(m) - assert jnp.alltrue(jnp.linalg.eigvals(pd_m) > 0) + assert jnp.all(jnp.linalg.eigvals(pd_m) > 0) @pytest.mark.parametrize("batch_shape", [(), (2,), (3, 1)]) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 65b33b99f..368fde593 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -22,7 +22,7 @@ from numpyro.infer.reparam import TransformReparam from numpyro.infer.sa import _get_proposal_loc_and_scale, _numpy_delete from numpyro.infer.util import initialize_model -from numpyro.util import fori_collect +from numpyro.util import fori_collect, is_prng_key @pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) @@ -406,8 +406,14 @@ def model(data): tree_all( tree_map( partial(assert_allclose, atol=1e-4, rtol=1e-4), - mcmc1.post_warmup_state, - mcmc.post_warmup_state, + tree_map( + lambda x: random.key_data(x) if is_prng_key(x) else x, + mcmc1.post_warmup_state, + ), + tree_map( + lambda x: random.key_data(x) if is_prng_key(x) else x, + mcmc.post_warmup_state, + ), ) ) @@ -911,7 +917,7 @@ def model(cov): cov = np.zeros((5, 5)) cov[:2, :2] = w_cov cov[2:4, 2:4] = xy_cov - cov[4, 4] = z_var + cov[4, 4] = z_var[0] kernel = NUTS(model, dense_mass=[("w",), ("x", "y")]) mcmc = MCMC(kernel, num_warmup=1000, num_samples=1) diff --git a/test/ops/test_provenance.py b/test/ops/test_provenance.py index 22d41dbda..5f9361422 100644 --- a/test/ops/test_provenance.py +++ b/test/ops/test_provenance.py @@ -8,7 +8,7 @@ import jax from jax.api_util import flatten_fun_nokwargs import jax.core as core -import jax.linear_util as lu +import jax.extend.linear_util as lu import jax.numpy as jnp from numpyro.ops.provenance import eval_provenance