diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst index 1fc7ad7c7..3cec9aa0b 100644 --- a/docs/source/utilities.rst +++ b/docs/source/utilities.rst @@ -40,9 +40,9 @@ log_density ----------- .. autofunction:: numpyro.infer.util.log_density -log_densities -------------- -.. autofunction:: numpyro.infer.util.log_densities +compute_log_probs +----------------- +.. autofunction:: numpyro.infer.util.compute_log_probs get_transforms -------------- diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 478971228..a3b7425d0 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -54,7 +54,7 @@ def process_message(self, msg): msg["value"] = random.PRNGKey(0) -def log_densities( +def compute_log_probs( model, model_args: tuple, model_kwargs: dict, @@ -116,7 +116,7 @@ def log_density(model, model_args: tuple, model_kwargs: dict, params: dict): :param params: Dictionary of current parameter values keyed by site name. :return: Log of joint density and a corresponding model trace. """ - log_joint, model_trace = log_densities(model, model_args, model_kwargs, params) + log_joint, model_trace = compute_log_probs(model, model_args, model_kwargs, params) # We need to start with 0.0 instead of 0 because log_joint may be empty or only # contain integers, but log_density must be a floating point value to be # differentiable by jax. diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 0505f16fa..ab0133700 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -27,9 +27,9 @@ from numpyro.infer.reparam import TransformReparam from numpyro.infer.util import ( Predictive, + compute_log_probs, constrain_fn, initialize_model, - log_densities, log_density, log_likelihood, potential_energy, @@ -268,7 +268,7 @@ def test_log_likelihood(batch_shape): ) -def test_log_densities(): +def test_compute_log_probs(): model, data, _ = beta_bernoulli() samples = Predictive(model, return_sites=["beta"], num_samples=1)(random.key(7)) samples = {key: value[0] for key, value in samples.items()} @@ -276,11 +276,11 @@ def test_log_densities(): logden, _ = log_density(model, (data,), {}, samples) assert logden.shape == () - logdens, _ = log_densities(model, (data,), {}, samples) + logdens, _ = compute_log_probs(model, (data,), {}, samples) assert set(logdens) == {"beta", "obs"} assert all(x.shape == () for x in logdens.values()) - logdens, _ = log_densities(model, (data,), {}, samples, False) + logdens, _ = compute_log_probs(model, (data,), {}, samples, False) assert logdens["beta"].shape == (2,) assert logdens["obs"].shape == (800, 2)