Skip to content

Commit

Permalink
Support custom prng key (#1642)
Browse files Browse the repository at this point in the history
* support custom prng key

* run black

* test custom prng in CI

* fix some deprecation warnings
  • Loading branch information
fehiepsi authored Sep 21, 2023
1 parent ca96eca commit 115c4d3
Show file tree
Hide file tree
Showing 23 changed files with 76 additions and 54 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ jobs:
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
- name: Test custom prng
run: |
JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py
examples:
Expand Down
5 changes: 2 additions & 3 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import operator

from jax import grad, jacfwd, numpy as jnp, random, vmap
from jax.random import KeyArray
from jax.tree_util import tree_map

from numpyro import handlers
Expand Down Expand Up @@ -370,10 +369,10 @@ def _update_force(attr_force, rep_force, jac):
)
return jnp.linalg.norm(particle_grads), res_grads

def init(self, rng_key: KeyArray, *args, **kwargs):
def init(self, rng_key, *args, **kwargs):
"""Register random variable transformations, constraints and determine initialize positions of the particles.
:param KeyArray rng_key: Random number generator seed.
:param rng_key: Random number generator seed.
:param args: Arguments to the model / guide.
:param kwargs: Keyword arguments to the model / guide.
:return: initial :data:`SteinVIState`
Expand Down
6 changes: 3 additions & 3 deletions numpyro/contrib/tfp/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numpyro.infer import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import initialize_model
from numpyro.util import identity
from numpyro.util import identity, is_prng_key

TFPKernelState = namedtuple("TFPKernelState", ["z", "kernel_results", "rng_key"])

Expand Down Expand Up @@ -174,7 +174,7 @@ def init(
self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
):
# non-vectorized
if rng_key.ndim == 1:
if is_prng_key(rng_key):
rng_key, rng_key_init_model = random.split(rng_key)
# vectorized
else:
Expand All @@ -190,7 +190,7 @@ def init(
" `target_log_prob_fn`."
)

if rng_key.ndim == 1:
if is_prng_key(rng_key):
init_state = self._init_fn(init_params, rng_key)
else:
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
ZeroInflatedDistribution,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample
from numpyro.distributions.util import promote_shapes, validate_sample
from numpyro.util import is_prng_key


def _log_beta_1(alpha, value):
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@
betaincinv,
cholesky_of_inverse,
gammaincinv,
is_prng_key,
lazy_property,
matrix_to_tril_vec,
promote_shapes,
signed_stick_breaking_tril,
validate_sample,
vec_to_tril_matrix,
)
from numpyro.util import is_prng_key


class AsymmetricLaplace(Distribution):
Expand Down
8 changes: 2 additions & 6 deletions numpyro/distributions/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@
import numpyro.distributions.constraints as constraints
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
is_prng_key,
lazy_property,
validate_sample,
)
from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample
from numpyro.util import is_prng_key


class GaussianCopula(Distribution):
Expand Down
3 changes: 1 addition & 2 deletions numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
is_prng_key,
lazy_property,
promote_shapes,
safe_normalize,
validate_sample,
von_mises_centered,
)
from numpyro.util import while_loop
from numpyro.util import is_prng_key, while_loop


def _numel(shape):
Expand Down
3 changes: 1 addition & 2 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@
binomial,
categorical,
clamp_probs,
is_prng_key,
lazy_property,
multinomial,
promote_shapes,
validate_sample,
)
from numpyro.util import not_jax_tracer
from numpyro.util import is_prng_key, not_jax_tracer


def _to_probs_bernoulli(logits):
Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from numpyro.distributions import Distribution, constraints
from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs
from numpyro.distributions.util import is_prng_key, validate_sample
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key


def Mixture(mixing_distribution, component_distributions, *, validate_args=None):
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
is_prng_key,
lazy_property,
promote_shapes,
validate_sample,
)
from numpyro.util import is_prng_key


class LeftTruncatedDistribution(Distribution):
Expand Down
7 changes: 4 additions & 3 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import namedtuple
from functools import partial, update_wrapper
import math
import warnings

import numpy as np

Expand Down Expand Up @@ -612,11 +613,11 @@ def safe_normalize(x, *, p=2):
return x


# src: https://github.com/google/jax/blob/5a41779fbe12ba7213cd3aa1169d3b0ffb02a094/jax/_src/random.py#L95
def is_prng_key(key):
if isinstance(key, jax.random.PRNGKeyArray):
return key.shape == ()
warnings.warn("Please use numpyro.util.is_prng_key.", DeprecationWarning)
try:
if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
return key.shape == ()
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False
Expand Down
16 changes: 8 additions & 8 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
apply_stack,
plate,
)
from numpyro.util import find_stack_level, not_jax_tracer
from numpyro.util import find_stack_level, is_prng_key, not_jax_tracer

__all__ = [
"block",
Expand Down Expand Up @@ -705,15 +705,15 @@ class seed(Messenger):
"""

def __init__(self, fn=None, rng_seed=None, hide_types=None):
if isinstance(rng_seed, int) or (
isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed)
if not is_prng_key(rng_seed) and (
isinstance(rng_seed, int)
or (
isinstance(rng_seed, (np.ndarray, jnp.ndarray))
and not jnp.shape(rng_seed)
)
):
rng_seed = random.PRNGKey(rng_seed)
if not (
isinstance(rng_seed, (np.ndarray, jnp.ndarray))
and rng_seed.dtype == jnp.uint32
and rng_seed.shape == (2,)
):
if not is_prng_key(rng_seed):
raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed)))
self.rng_key = rng_seed
self.hide_types = [] if hide_types is None else hide_types
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numpyro.infer.initialization import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import initialize_model
from numpyro.util import identity
from numpyro.util import identity, is_prng_key

BarkerMHState = namedtuple(
"BarkerMHState",
Expand Down Expand Up @@ -170,7 +170,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params):
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
self._num_warmup = num_warmup
# TODO (low-priority): support chain_method="vectorized", i.e. rng_key is a batch of keys
assert rng_key.shape == (2,), (
assert is_prng_key(rng_key), (
"BarkerMH only supports chain_method='parallel' or chain_method='sequential'."
" Please put in a feature request if you think it would be useful to be able "
"to use BarkerMH in vectorized mode."
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
from numpyro.util import cond, fori_loop, identity
from numpyro.util import cond, fori_loop, identity, is_prng_key

HMCState = namedtuple(
"HMCState",
Expand Down Expand Up @@ -703,7 +703,7 @@ def init(
self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
):
# non-vectorized
if rng_key.ndim == 1:
if is_prng_key(rng_key):
rng_key, rng_key_init_model = random.split(rng_key)
# vectorized
else:
Expand Down Expand Up @@ -749,7 +749,7 @@ def init(
model_kwargs=model_kwargs,
rng_key=rng_key,
)
if rng_key.ndim == 1:
if is_prng_key(rng_key):
init_state = hmc_init_fn(init_params, rng_key)
else:
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
Expand Down
12 changes: 9 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from jax.tree_util import tree_flatten, tree_map

from numpyro.diagnostics import print_summary
from numpyro.util import cached_by, find_stack_level, fori_collect, identity
from numpyro.util import (
cached_by,
find_stack_level,
fori_collect,
identity,
is_prng_key,
)

__all__ = [
"MCMCKernel",
Expand Down Expand Up @@ -418,7 +424,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
sample_fn, postprocess_fn = self._get_cached_fns()
diagnostics = (
lambda x: self.sampler.get_diagnostics_str(x[0])
if rng_key.ndim == 1
if is_prng_key(rng_key)
else ""
) # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
Expand Down Expand Up @@ -595,7 +601,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
self._args = args
self._kwargs = kwargs
init_state = self._get_cached_init_state(rng_key, args, kwargs)
if self.num_chains > 1 and rng_key.ndim == 1:
if self.num_chains > 1 and is_prng_key(rng_key):
rng_key = random.split(rng_key, self.num_chains)

if self._warmup_state is not None:
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpyro.distributions.util import cholesky_update
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import init_to_uniform, initialize_model
from numpyro.util import identity
from numpyro.util import identity, is_prng_key


def _get_proposal_loc_and_scale(samples, loc, scale, new_sample):
Expand Down Expand Up @@ -331,7 +331,7 @@ def init(
self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
):
# non-vectorized
if rng_key.ndim == 1:
if is_prng_key(rng_key):
rng_key, rng_key_init_model = random.split(rng_key)
# vectorized
else:
Expand All @@ -358,7 +358,7 @@ def init(
model_args=model_args,
model_kwargs=model_kwargs,
)
if rng_key.ndim == 1:
if is_prng_key(rng_key):
init_state = sa_init_fn(init_params, rng_key)
else:
init_state = vmap(sa_init_fn)(init_params, rng_key)
Expand Down
8 changes: 5 additions & 3 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from numpyro.util import (
_validate_model,
find_stack_level,
is_prng_key,
not_jax_tracer,
soft_vmap,
while_loop,
Expand Down Expand Up @@ -435,7 +436,7 @@ def _find_valid_params(rng_key, exit_early=False):
return (init_params, pe, z_grad), is_valid

# Handle possible vectorization
if rng_key.ndim == 1:
if is_prng_key(rng_key):
(init_params, pe, z_grad), is_valid = _find_valid_params(
rng_key, exit_early=True
)
Expand Down Expand Up @@ -644,7 +645,7 @@ def initialize_model(
"""
model_kwargs = {} if model_kwargs is None else model_kwargs
substituted_model = substitute(
seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
substitute_fn=init_strategy,
)
(
Expand Down Expand Up @@ -816,9 +817,10 @@ def single_prediction(val):
return {name: value for name, value in pred_samples.items() if name in sites}

num_samples = int(np.prod(batch_shape))
key_shape = rng_key.shape
if num_samples > 1:
rng_key = random.split(rng_key, num_samples)
rng_key = rng_key.reshape((*batch_shape, 2))
rng_key = rng_key.reshape(batch_shape + key_shape)
chunk_size = num_samples if parallel else 1
return soft_vmap(
single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
Expand Down
2 changes: 1 addition & 1 deletion numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax.experimental.pjit import pjit_p
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.linear_util as lu
import jax.extend.linear_util as lu
import jax.numpy as jnp


Expand Down
9 changes: 9 additions & 0 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def fori_loop(lower, upper, body_fun, init_val):
return lax.fori_loop(lower, upper, body_fun, init_val)


def is_prng_key(key):
try:
if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
return key.shape == ()
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False


def not_jax_tracer(x):
"""
Checks if `x` is not an array generated inside `jit`, `pmap`, `vmap`, or `lax_control_flow`.
Expand Down
4 changes: 2 additions & 2 deletions test/contrib/einstein/test_steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def model(obs):
expected_shape = (num_particles, *np.shape(inner_param["value"]))
assert init_value.shape == expected_shape
if "auto_loc" in name or name == "b":
assert np.alltrue(init_value != np.zeros(expected_shape))
assert np.all(init_value != np.zeros(expected_shape))
assert np.unique(init_value).shape == init_value.reshape(-1).shape
elif "scale" in name:
assert_allclose(init_value[init_value != 0.0], 0.1, rtol=1e-6)
Expand Down Expand Up @@ -311,7 +311,7 @@ def model(obs):
expected_shape = (num_particles, latent_dim)

assert expected_shape == init_value.shape
assert np.alltrue(init_value != np.zeros(expected_shape))
assert np.all(init_value != np.zeros(expected_shape))
assert np.unique(init_value).shape == init_value.reshape(-1).shape


Expand Down
Loading

0 comments on commit 115c4d3

Please sign in to comment.