Skip to content

Commit

Permalink
Adding Powerlaw prior class
Browse files Browse the repository at this point in the history
  • Loading branch information
tsunhopang committed Dec 4, 2023
1 parent e59fa02 commit ee7bc8a
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,79 @@ def log_prob(self, x: dict) -> Float:
return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]]))


class Powerlaw(Prior):

"""
A prior following the power-law with alpha in the range [xmin, xmax).
p(x) ~ x^{\alpha}
"""

xmin: float = 0.0
xmax: float = 1.0
alpha: int = 0.0

def __init__(
self,
xmin: float,
xmax: float,
alpha: int | float,
naming: list[str],
transforms: dict[tuple[str, Callable]] = {},
):
super().__init__(naming, transforms)
assert isinstance(xmin, float), "xmin must be a float"
assert isinstance(xmax, float), "xmax must be a float"
assert isinstance(alpha, (int, float)), "alpha must be a int or a float"
if alpha < 0.:
assert alpha < 0. or xmin > 0., "With negative alpha, xmin must > 0"
assert self.n_dim == 1, "Powerlaw needs to be 1D distributions"
self.xmax = xmax
self.xmin = xmin
self.alpha = alpha

def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
"""
Sample from a power-law distribution.
Parameters
----------
rng_key : jax.random.PRNGKey
A random key to use for sampling.
n_samples : int
The number of samples to draw.
Returns
-------
samples : dict
Samples from the distribution. The keys are the names of the parameters.
"""
q_samples = jax.random.uniform(
rng_key, (n_samples,), minval=0., maxval=1.
)
if self.alpha == -1:
samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin))
else:
samples = (self.xmin ** (1. + self.alpha) + q_samples *
(self.xmax ** (1. + self.alpha) - self.xmin ** (1. + self.alpha))) ** (1. / (1. + self.alpha))
return self.add_name(samples[None])

def log_prob(self, x: dict) -> Float:
variable = x[self.naming[0]]
if self.alpha == -1:
normalization_constant = 1. / jnp.log(self.xmax / self.xmin)
else:
normalization_constant = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) -
self.xmin ** (1 + self.alpha))
log_in_range = jnp.where(
(variable >= self.xmax) | (variable <= self.xmin),
jnp.zeros_like(variable) - jnp.inf,
jnp.zeros_like(variable),
)
log_p = self.alpha * jnp.log(variable) + jnp.log(normalization_constant)
return log_p + log_in_range


class Composite(Prior):

priors: list[Prior] = field(default_factory=list)
Expand Down

0 comments on commit ee7bc8a

Please sign in to comment.