Skip to content

Commit

Permalink
Implement ICDF Methods for Truncated Distributions (#1938)
Browse files Browse the repository at this point in the history
* Add icdf methods to generic truncated distributions

* Restricted icdf function to [0,1]
  • Loading branch information
TheSkyentist authored Dec 18, 2024
1 parent 4b33db1 commit 8e1d9b2
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ def sample(self, key, sample_shape=()):
finfo = jnp.finfo(dtype)
minval = finfo.tiny
u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
return self.icdf(u)

def icdf(self, q):
loc = self.base_dist.loc
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return (1 - sign) * loc + sign * self.base_dist.icdf(
(1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
(1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high
)
return jnp.where(q < 0, jnp.nan, ppf)

@validate_sample
def log_prob(self, value):
Expand Down Expand Up @@ -138,7 +142,11 @@ def sample(self, key, sample_shape=()):
finfo = jnp.finfo(dtype)
minval = finfo.tiny
u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
return self.base_dist.icdf(u * self._cdf_at_high)
return self.icdf(u)

def icdf(self, q):
ppf = self.base_dist.icdf(q * self._cdf_at_high)
return jnp.where(q > 1, jnp.nan, ppf)

@validate_sample
def log_prob(self, value):
Expand Down Expand Up @@ -235,19 +243,22 @@ def sample(self, key, sample_shape=()):
finfo = jnp.finfo(dtype)
minval = finfo.tiny
u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
return self.icdf(u)

def icdf(self, q):
# NB: we use a more numerically stable formula for a symmetric base distribution
# A = icdf(cdf(low) + (cdf(high) - cdf(low)) * u) = icdf[(1 - u) * cdf(low) + u * cdf(high)]
# A = icdf(cdf(low) + (cdf(high) - cdf(low)) * q) = icdf[(1 - q) * cdf(low) + q * cdf(high)]
# will suffer by precision issues when low is large;
# If low < loc:
# A = icdf[(1 - u) * cdf(low) + u * cdf(high)]
# A = icdf[(1 - q) * cdf(low) + q * cdf(high)]
# Else
# A = 2 * loc - icdf[(1 - u) * cdf(2*loc-low)) + u * cdf(2*loc - high)]
# A = 2 * loc - icdf[(1 - q) * cdf(2*loc-low)) + q * cdf(2*loc - high)]
loc = self.base_dist.loc
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return (1 - sign) * loc + sign * self.base_dist.icdf(
clamp_probs((1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high)
ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
clamp_probs((1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high)
)
return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf)

@validate_sample
def log_prob(self, value):
Expand Down

0 comments on commit 8e1d9b2

Please sign in to comment.