Skip to content

Commit

Permalink
Rename random_state -> rng_key in numpyro.sample (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored and fehiepsi committed Nov 1, 2019
1 parent 9b61538 commit a2e5990
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _validate_sample(self, value):
return mask

def __call__(self, *args, **kwargs):
key = kwargs.pop('random_state')
key = kwargs.pop('rng_key')
sample_intermediates = kwargs.pop('sample_intermediates', False)
if sample_intermediates:
return self.sample_with_intermediates(key, *args, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class trace(Messenger):
{'args': (),
'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
'is_observed': False,
'kwargs': {'random_state': DeviceArray([0, 0], dtype=uint32)},
'kwargs': {'rng_key': DeviceArray([0, 0], dtype=uint32)},
'name': 'a',
'type': 'sample',
'value': DeviceArray(-0.20584235, dtype=float32)})])
Expand Down Expand Up @@ -362,9 +362,9 @@ def __init__(self, fn=None, rng_seed=None, rng=None):

def process_message(self, msg):
if msg['type'] == 'sample' and not msg['is_observed'] and \
msg['kwargs']['random_state'] is None:
msg['kwargs']['rng_key'] is None:
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg['kwargs']['random_state'] = rng_key_sample
msg['kwargs']['rng_key'] = rng_key_sample


class substitute(Messenger):
Expand Down
10 changes: 5 additions & 5 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)


def sample(name, fn, obs=None, random_state=None, sample_shape=()):
def sample(name, fn, obs=None, rng_key=None, sample_shape=()):
"""
Returns a random sample from the stochastic function `fn`. This can have
additional side effects when wrapped inside effect handlers like
Expand All @@ -69,27 +69,27 @@ def sample(name, fn, obs=None, random_state=None, sample_shape=()):
.. note::
By design, `sample` primitive is meant to be used inside a NumPyro model.
Then :class:`~numpyro.handlers.seed` handler is used to inject a random
state to `fn`. In those situations, `random_state` keyword will take no
state to `fn`. In those situations, `rng_key` keyword will take no
effect.
:param str name: name of the sample site
:param fn: Python callable
:param numpy.ndarray obs: observed value
:param jax.random.PRNGKey random_state: an optional random key for `fn`.
:param jax.random.PRNGKey rng_key: an optional random key for `fn`.
:param sample_shape: Shape of samples to be drawn.
:return: sample from the stochastic `fn`.
"""
# if there are no active Messengers, we just draw a sample and return it as expected:
if not _PYRO_STACK:
return fn(random_state=random_state, sample_shape=sample_shape)
return fn(rng_key=rng_key, sample_shape=sample_shape)

# Otherwise, we initialize a message...
initial_msg = {
'type': 'sample',
'name': name,
'fn': fn,
'args': (),
'kwargs': {'random_state': random_state, 'sample_shape': sample_shape},
'kwargs': {'rng_key': rng_key, 'sample_shape': sample_shape},
'value': obs,
'scale': 1.0,
'is_observed': obs is not None,
Expand Down

0 comments on commit a2e5990

Please sign in to comment.