Skip to content

Commit

Permalink
Minor changes to infer.util for 0.2.2 (#487)
Browse files Browse the repository at this point in the history
* Minor changes to infer.util for 0.2.2

* fix test

* fix test; address comment

* fix invocation
  • Loading branch information
neerajprad authored and fehiepsi committed Dec 4, 2019
1 parent db872b1 commit cc777e8
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 44 deletions.
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def main(args):
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
init_params, potential_fn, constrain_fn = initialize_model(init_rng_key, model, returns)
init_params, potential_fn, constrain_fn = initialize_model(init_rng_key, model, model_args=(returns,))
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup, rng_key=sample_rng_key)
hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,
Expand Down
5 changes: 3 additions & 2 deletions numpyro/contrib/autoguide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@ def __init__(self, model, prefix="auto", init_strategy=init_to_uniform()):
def _setup_prototype(self, *args, **kwargs):
super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix), dist.PRNGIdentity())
init_params, _ = handlers.block(find_valid_initial_params)(rng_key, self.model, *args,
init_params, _ = handlers.block(find_valid_initial_params)(rng_key, self.model,
init_strategy=self.init_strategy,
**kwargs)
model_args=args,
model_kwargs=kwargs)
self._inv_transforms = {}
self._has_transformed_dist = False
unconstrained_sites = {}
Expand Down
7 changes: 4 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):
... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
... model, data, labels)
... model, model_args=(data, labels,))
>>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(init_params,
... trajectory_length=10,
Expand Down Expand Up @@ -495,10 +495,11 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg
' `potential_fn`.')
# Find valid initial params
if self._model and not init_params:
init_params, is_valid = find_valid_initial_params(rng_key, self._model, *model_args,
init_params, is_valid = find_valid_initial_params(rng_key, self._model,
init_strategy=self._init_strategy,
param_as_improper=True,
**model_kwargs)
model_args=model_args,
model_kwargs=model_kwargs)
if not_jax_tracer(is_valid):
if device_get(~np.all(is_valid)):
raise RuntimeError("Cannot find valid initial parameters. "
Expand Down
67 changes: 38 additions & 29 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@

def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False):
"""
Computes log of joint density for the model given latent values ``params``.
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
latent values ``params``.
:param model: Python callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs`: kwargs provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:param bool skip_dist_transforms: whether to compute log probability of a site
Expand Down Expand Up @@ -76,8 +77,9 @@ def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=Fa

def transform_fn(transforms, params, invert=False):
"""
Callable that applies a transformation from the `transforms` dict to values in the
`params` dict and returns the transformed values keyed on the same names.
(EXPERIMENTAL INTERFACE) Callable that applies a transformation from the `transforms`
dict to values in the `params` dict and returns the transformed values keyed on
the same names.
:param transforms: Dictionary of transforms keyed by names. Names in
`transforms` and `params` should align.
Expand All @@ -93,17 +95,18 @@ def transform_fn(transforms, params, invert=False):

def constrain_fn(model, transforms, model_args, model_kwargs, params):
"""
Gets value at each latent site in `model` given unconstrained parameters `params`.
The `transforms` is used to transform these unconstrained parameters to base values
of the corresponding priors in `model`. If a prior is a transformed distribution,
the corresponding base value lies in the support of base distribution. Otherwise,
the base value lies in the support of the distribution.
(EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given
unconstrained parameters `params`. The `transforms` is used to transform these
unconstrained parameters to base values of the corresponding priors in `model`.
If a prior is a transformed distribution, the corresponding base value lies in
the support of base distribution. Otherwise, the base value lies in the support
of the distribution.
:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict transforms: dictionary of transforms keyed by names. Names in
`transforms` and `params` should align.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of unconstrained values keyed by site
names.
:return: `dict` of transformed params.
Expand All @@ -116,16 +119,16 @@ def constrain_fn(model, transforms, model_args, model_kwargs, params):

def potential_energy(model, inv_transforms, model_args, model_kwargs, params):
"""
Computes potential energy of a model given unconstrained params.
(EXPERIMENTAL INTERFACE) Computes potential energy of a model given unconstrained params.
The `inv_transforms` is used to transform these unconstrained parameters to base values
of the corresponding priors in `model`. If a prior is a transformed distribution,
the corresponding base value lies in the support of base distribution. Otherwise,
the base value lies in the support of the distribution.
:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs`: kwargs provided to the model.
:param dict inv_transforms: dictionary of transforms keyed by names.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: unconstrained parameters of `model`.
:return: potential energy given unconstrained parameters.
"""
Expand Down Expand Up @@ -268,8 +271,11 @@ def init_to_value(values):
return partial(_init_to_value, values=values)


def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to_uniform(),
param_as_improper=False, **model_kwargs):
def find_valid_initial_params(rng_key, model,
init_strategy=init_to_uniform(),
param_as_improper=False,
model_args=(),
model_kwargs=None):
"""
(EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial
valid unconstrained value for all the parameters. This function also returns an
Expand All @@ -281,11 +287,11 @@ def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to
sample from the prior. The returned `init_params` will have the
batch shape ``rng_key.shape[:-1]``.
:param model: Python callable containing Pyro primitives.
:param `*model_args`: args provided to the model.
:param callable init_strategy: a per-site initialization function.
:param bool param_as_improper: a flag to decide whether to consider sites with
`param` statement as sites with improper priors.
:param `**model_kwargs`: kwargs provided to the model.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:return: tuple of (`init_params`, `is_valid`).
"""
init_strategy = jax.partial(init_strategy, skip_param=not param_as_improper)
Expand Down Expand Up @@ -416,8 +422,11 @@ def constrain_fun(*args, **kwargs):
return potential_fn, constrain_fun


def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform(),
dynamic_args=False, **model_kwargs):
def initialize_model(rng_key, model,
init_strategy=init_to_uniform(),
dynamic_args=False,
model_args=(),
model_kwargs=None):
"""
(EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
Expand All @@ -427,30 +436,33 @@ def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform(
sample from the prior. The returned `init_params` will have the
batch shape ``rng_key.shape[:-1]``.
:param model: Python callable containing Pyro primitives.
:param `*model_args`: args provided to the model.
:param callable init_strategy: a per-site initialization function.
See :ref:`init_strategy` section for available functions.
:param bool dynamic_args: if `True`, the `potential_fn` and
`constraints_fn` are themselves dependent on model arguments.
When provided a `*model_args, **model_kwargs`, they return
`potential_fn` and `constraints_fn` callables, respectively.
:param `**model_kwargs`: kwargs provided to the model.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
`init_params` are values from the prior used to initiate MCMC,
`constrain_fn` is a callable that uses inverse transforms
to convert unconstrained HMC samples to constrained values that
lie within the site's support.
"""
if model_kwargs is None:
model_kwargs = {}
potential_fun, constrain_fun = get_potential_fn(rng_key if rng_key.ndim == 1 else rng_key[0],
model,
dynamic_args=dynamic_args,
model_args=model_args,
model_kwargs=model_kwargs)

init_params, is_valid = find_valid_initial_params(rng_key, model, *model_args,
init_params, is_valid = find_valid_initial_params(rng_key, model,
init_strategy=init_strategy,
param_as_improper=True,
**model_kwargs)
model_args=model_args,
model_kwargs=model_kwargs)

if not_jax_tracer(is_valid):
if device_get(~np.all(is_valid)):
Expand Down Expand Up @@ -559,11 +571,8 @@ def get_samples(self, rng_key, *args, **kwargs):

def log_likelihood(model, posterior_samples, *args, **kwargs):
"""
Returns log likelihood at observation nodes of model, given samples of all latent variables.
.. warning::
The interface for the `log_likelihood` function is experimental, and
might change in the future.
(EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model,
given samples of all latent variables.
:param model: Python callable containing Pyro primitives.
:param dict posterior_samples: dictionary of samples from the posterior.
Expand Down
20 changes: 12 additions & 8 deletions test/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,13 @@ def model(data):
])

rng_keys = random.split(random.PRNGKey(1), 2)
init_params, _, _ = initialize_model(rng_keys, model, count_data,
init_strategy=init_strategy)
init_params, _, _ = initialize_model(rng_keys, model,
init_strategy=init_strategy,
model_args=(count_data,))
for i in range(2):
init_params_i, _, _ = initialize_model(rng_keys[i], model, count_data,
init_strategy=init_strategy)
init_params_i, _, _ = initialize_model(rng_keys[i], model,
init_strategy=init_strategy,
model_args=(count_data,))
for name, p in init_params.items():
# XXX: the result is equal if we disable fast-math-mode
assert_allclose(p[i], init_params_i[name], atol=1e-6)
Expand All @@ -219,11 +221,13 @@ def model(data):
data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))

rng_keys = random.split(random.PRNGKey(1), 2)
init_params, _, _ = initialize_model(rng_keys, model, data,
init_strategy=init_strategy)
init_params, _, _ = initialize_model(rng_keys, model,
init_strategy=init_strategy,
model_args=(data,))
for i in range(2):
init_params_i, _, _ = initialize_model(rng_keys[i], model, data,
init_strategy=init_strategy)
init_params_i, _, _ = initialize_model(rng_keys[i], model,
init_strategy=init_strategy,
model_args=(data,))
for name, p in init_params.items():
# XXX: the result is equal if we disable fast-math-mode
assert_allclose(p[i], init_params_i[name], atol=1e-6)
2 changes: 1 addition & 1 deletion test/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def model(data):

true_probs = np.array([0.9, 0.1])
data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, model_args=(data,))
init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
hmc_state = init_kernel(init_params,
trajectory_length=1.,
Expand Down

0 comments on commit cc777e8

Please sign in to comment.