Skip to content

Commit

Permalink
Rename log_densities to compute_log_probs.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Dec 5, 2024
1 parent ac89884 commit f287751
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions docs/source/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ log_density
-----------
.. autofunction:: numpyro.infer.util.log_density

log_densities
-------------
.. autofunction:: numpyro.infer.util.log_densities
compute_log_probs
-----------------
.. autofunction:: numpyro.infer.util.compute_log_probs

get_transforms
--------------
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def process_message(self, msg):
msg["value"] = random.PRNGKey(0)


def log_densities(
def compute_log_probs(
model,
model_args: tuple,
model_kwargs: dict,
Expand Down Expand Up @@ -116,7 +116,7 @@ def log_density(model, model_args: tuple, model_kwargs: dict, params: dict):
:param params: Dictionary of current parameter values keyed by site name.
:return: Log of joint density and a corresponding model trace.
"""
log_joint, model_trace = log_densities(model, model_args, model_kwargs, params)
log_joint, model_trace = compute_log_probs(model, model_args, model_kwargs, params)
# We need to start with 0.0 instead of 0 because log_joint may be empty or only
# contain integers, but log_density must be a floating point value to be
# differentiable by jax.
Expand Down
8 changes: 4 additions & 4 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from numpyro.infer.reparam import TransformReparam
from numpyro.infer.util import (
Predictive,
compute_log_probs,
constrain_fn,
initialize_model,
log_densities,
log_density,
log_likelihood,
potential_energy,
Expand Down Expand Up @@ -268,19 +268,19 @@ def test_log_likelihood(batch_shape):
)


def test_log_densities():
def test_compute_log_probs():
model, data, _ = beta_bernoulli()
samples = Predictive(model, return_sites=["beta"], num_samples=1)(random.key(7))
samples = {key: value[0] for key, value in samples.items()}

logden, _ = log_density(model, (data,), {}, samples)
assert logden.shape == ()

logdens, _ = log_densities(model, (data,), {}, samples)
logdens, _ = compute_log_probs(model, (data,), {}, samples)
assert set(logdens) == {"beta", "obs"}
assert all(x.shape == () for x in logdens.values())

logdens, _ = log_densities(model, (data,), {}, samples, False)
logdens, _ = compute_log_probs(model, (data,), {}, samples, False)
assert logdens["beta"].shape == (2,)
assert logdens["obs"].shape == (800, 2)

Expand Down

0 comments on commit f287751

Please sign in to comment.