From ee7bc8aa2d4e338e0f62edbf90caec7b42f2b5bb Mon Sep 17 00:00:00 2001
From: "Peter T. H. Pang"
Date: Mon, 4 Dec 2023 15:50:22 +0100
Subject: [PATCH] Adding Powerlaw prior class

---
 src/jimgw/prior.py | 75 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py
index 2137b018..b58436b8 100644
--- a/src/jimgw/prior.py
+++ b/src/jimgw/prior.py
@@ -240,6 +240,81 @@ def log_prob(self, x: dict) -> Float:
         return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]]))
 
 
+class Powerlaw(Prior):
+
+    """
+    A power-law prior with support on the range [xmin, xmax).
+    p(x) ~ x^alpha
+    """
+
+    xmin: float = 0.0
+    xmax: float = 1.0
+    alpha: float = 0.0
+
+    def __init__(
+        self,
+        xmin: float,
+        xmax: float,
+        alpha: int | float,
+        naming: list[str],
+        transforms: dict[tuple[str, Callable]] = {},
+    ):
+        super().__init__(naming, transforms)
+        assert isinstance(xmin, float), "xmin must be a float"
+        assert isinstance(xmax, float), "xmax must be a float"
+        assert isinstance(alpha, (int, float)), "alpha must be an int or a float"
+        if alpha < 0.:
+            assert xmin > 0., "With negative alpha, xmin must be > 0"
+        assert self.n_dim == 1, "Powerlaw needs to be a 1D distribution"
+        self.xmax = xmax
+        self.xmin = xmin
+        self.alpha = alpha
+
+    def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
+        """
+        Sample from a power-law distribution.
+
+        Parameters
+        ----------
+        rng_key : jax.random.PRNGKey
+            A random key to use for sampling.
+        n_samples : int
+            The number of samples to draw.
+
+        Returns
+        -------
+        samples : dict
+            Samples from the distribution. The keys are the names of the parameters.
+
+        """
+        q_samples = jax.random.uniform(
+            rng_key, (n_samples,), minval=0., maxval=1.
+        )
+        # Inverse-CDF sampling; alpha == -1 reduces to a log-uniform distribution.
+        if self.alpha == -1:
+            samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin))
+        else:
+            samples = (self.xmin ** (1. + self.alpha) + q_samples *
+                       (self.xmax ** (1. + self.alpha) - self.xmin ** (1. + self.alpha))) ** (1. / (1. + self.alpha))
+        return self.add_name(samples[None])
+
+    def log_prob(self, x: dict) -> Float:
+        variable = x[self.naming[0]]
+        if self.alpha == -1:
+            normalization_constant = 1. / jnp.log(self.xmax / self.xmin)
+        else:
+            normalization_constant = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) -
+                                                         self.xmin ** (1 + self.alpha))
+        # Assign zero probability outside the support [xmin, xmax).
+        log_in_range = jnp.where(
+            (variable < self.xmin) | (variable >= self.xmax),
+            jnp.zeros_like(variable) - jnp.inf,
+            jnp.zeros_like(variable),
+        )
+        log_p = self.alpha * jnp.log(variable) + jnp.log(normalization_constant)
+        return log_p + log_in_range
+
+
 class Composite(Prior):
 
     priors: list[Prior] = field(default_factory=list)
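
Note for reviewers: the sampler inverts the power-law CDF. For alpha != -1
the inverse CDF is x(q) = (xmin^(1+alpha) + q * (xmax^(1+alpha) -
xmin^(1+alpha)))^(1/(1+alpha)), and at alpha == -1 it reduces to the
log-uniform form x(q) = xmin * (xmax/xmin)^q, with q uniform on [0, 1).

Below is a minimal usage sketch, not part of the diff. It assumes jimgw is
installed with Powerlaw exported from jimgw.prior as in this patch; the
parameter name "mass" and the Salpeter-like slope are arbitrary example
values, not anything prescribed by the change.

import jax
import jax.numpy as jnp

from jimgw.prior import Powerlaw

# p(m) ~ m^(-2.35) on [5, 100): a Salpeter-like power law (example values).
prior = Powerlaw(xmin=5.0, xmax=100.0, alpha=-2.35, naming=["mass"])

# Draw samples; the result is a dict keyed by the parameter name.
samples = prior.sample(jax.random.PRNGKey(0), 10000)

# Evaluate the log density on a grid and check it is normalized:
# exp(log_prob) should integrate to ~1 over [xmin, xmax).
grid = jnp.linspace(5.0, 100.0, 100001)
pdf = jnp.exp(prior.log_prob({"mass": grid}))
print(jnp.sum(pdf[:-1] * (grid[1:] - grid[:-1])))  # expect ~1.0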