Skip to content

Commit

Permalink
Restricted icdf function to [0,1]
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSkyentist committed Dec 18, 2024
1 parent 5bb9f5f commit 054a435
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ def sample(self, key, sample_shape=()):
def icdf(self, q):
loc = self.base_dist.loc
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return (1 - sign) * loc + sign * self.base_dist.icdf(
ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
(1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high
)
return jnp.where(q < 0, jnp.nan, ppf)

@validate_sample
def log_prob(self, value):
Expand Down Expand Up @@ -144,7 +145,8 @@ def sample(self, key, sample_shape=()):
return self.icdf(u)

def icdf(self, q):
return self.base_dist.icdf(q * self._cdf_at_high)
ppf = self.base_dist.icdf(q * self._cdf_at_high)
return jnp.where(q > 1, jnp.nan, ppf)

@validate_sample
def log_prob(self, value):
Expand Down Expand Up @@ -253,9 +255,10 @@ def icdf(self, q):
# A = 2 * loc - icdf[(1 - q) * cdf(2*loc-low)) + q * cdf(2*loc - high)]
loc = self.base_dist.loc
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return (1 - sign) * loc + sign * self.base_dist.icdf(
ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
clamp_probs((1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high)
)
return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf)

@validate_sample
def log_prob(self, value):
Expand Down

0 comments on commit 054a435

Please sign in to comment.