From c0008a517aff7c339dff037b14bb5d73772b904e Mon Sep 17 00:00:00 2001 From: Eike Petersen Date: Thu, 9 Nov 2023 19:26:41 +0100 Subject: [PATCH 1/2] Minor Kumaraswamy dist bug fixes --- numpyro/distributions/continuous.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index f5b7b81e6..9b1fd6b2d 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -765,7 +765,7 @@ def icdf(self, q): return self.loc - self.scale * jnp.log(-jnp.log(q)) -class Kumaraswamy(TransformedDistribution): +class Kumaraswamy(Distribution): arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, @@ -786,13 +786,7 @@ def __init__(self, concentration1, concentration0, *, validate_args=None): batch_shape = lax.broadcast_shapes( jnp.shape(concentration1), jnp.shape(concentration0) ) - base_dist = Uniform(0, 1).expand(batch_shape) - transforms = [ - PowerTransform(1 / concentration0), - AffineTransform(1, -1), - PowerTransform(1 / concentration1), - ] - super().__init__(base_dist, transforms, validate_args=validate_args) + super().__init__(batch_shape=batch_shape, validate_args=validate_args) def sample(self, key, sample_shape=()): assert is_prng_key(key) @@ -803,7 +797,7 @@ def sample(self, key, sample_shape=()): return jnp.clip(jnp.exp(log_sample), a_min=finfo.tiny, a_max=1 - finfo.eps) @validate_sample - def log_prob(self, value): + def log_prob(self, value, intermediates=None): normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1) return ( xlogy(self.concentration1 - 1, value) From 4fea01727d4df3728f6d7625c68c844c7257c809 Mon Sep 17 00:00:00 2001 From: Eike Petersen <1774207+e-pet@users.noreply.github.com> Date: Thu, 9 Nov 2023 21:55:30 +0100 Subject: [PATCH 2/2] Removing intermediates from Kuma log_prob again because no longer necessary --- numpyro/distributions/continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 9b1fd6b2d..e8cff6ca6 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -797,7 +797,7 @@ def sample(self, key, sample_shape=()): return jnp.clip(jnp.exp(log_sample), a_min=finfo.tiny, a_max=1 - finfo.eps) @validate_sample - def log_prob(self, value, intermediates=None): + def log_prob(self, value): normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1) return ( xlogy(self.concentration1 - 1, value)