Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ICDF Methods for Truncated Distributions #1938

Merged
merged 2 commits into from
Dec 18, 2024

Conversation

TheSkyentist
Copy link
Contributor

Implementation of generic Inverse Cumulative Distribution Function (ICDF) methods for Truncated Distributions.

Sampling method for Truncated Distribution already calculates the ICDF, so the initial commit is refactoring the code into its own method. Still needs to be tested so starting this as a draft.

Motivated by #1937 for use with Nested Sampling with jaxns.

@TheSkyentist
Copy link
Contributor Author

Fixed an issue where the icdf functions would return non-NaN values for inputs outside of the range [0,1], although the values in the range are correct. Now the icdf function returns a NaN for inputs outside the range [0,1], which makes it consistent with other distributions as well.

@TheSkyentist
Copy link
Contributor Author

All formatting and relevant distribution tests pass! Ready for review.

@TheSkyentist TheSkyentist marked this pull request as ready for review December 18, 2024 09:25
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @TheSkyentist Could you double check if our tests already cover the change?


def icdf(self, q):
ppf = self.base_dist.icdf(q * self._cdf_at_high)
return jnp.where(q > 1, jnp.nan, ppf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to cover q > 1?

numpyro/distributions/truncated.py Show resolved Hide resolved
)
return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like above, why this where is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above comment :-)

@TheSkyentist
Copy link
Contributor Author

I think the tests cover the change at least as far as I can see, but perhaps I'm not sure what you mean, could you clarify slightly? Thanks :-)

@fehiepsi fehiepsi merged commit 8e1d9b2 into pyro-ppl:master Dec 18, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants