Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utility function to evaluate log density for individual sites. #1932

Merged
merged 4 commits into from
Dec 6, 2024

Conversation

tillahoffmann
Copy link
Contributor

@tillahoffmann tillahoffmann commented Dec 5, 2024

This PR adds a utility function log_densities to the infer.util module to evaluate log densities for individual sites of the model. This is mostly useful for debugging and inspecting individual contributions to the summed log density and ELBO.

I've moved type annotations from the docstring to the signature for log_density which now calls log_densities but otherwise remains unchanged.

@@ -54,21 +54,28 @@ def process_message(self, msg):
msg["value"] = random.PRNGKey(0)


def log_density(model, model_args, model_kwargs, params):
def log_densities(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe compute_log_probs?

@tillahoffmann
Copy link
Contributor Author

I think the failing tests might be related to the version change of jax from 0.4.35 to 0.4.36.

@fehiepsi
Copy link
Member

fehiepsi commented Dec 6, 2024

Could you make a github issue and mark test_haiku_state_dropout_smoke as xfail?

@fehiepsi fehiepsi merged commit bf9c715 into pyro-ppl:master Dec 6, 2024
9 checks passed
@fehiepsi
Copy link
Member

fehiepsi commented Dec 6, 2024

Thanks, @tillahoffmann!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants