Put normalization computation in initialization
kazewong committed Dec 4, 2023
1 parent 43b6525 commit cb8d95b
Showing 1 changed file with 7 additions and 6 deletions.

src/jimgw/prior.py
@@ -352,6 +352,7 @@ class Powerlaw(Prior):
     xmin: float = 0.0
     xmax: float = 1.0
     alpha: int = 0.0
+    normalization: float = 1.0
 
     def __init__(
         self,
@@ -371,6 +372,11 @@ def __init__(
         self.xmax = xmax
         self.xmin = xmin
         self.alpha = alpha
+        if alpha == -1:
+            self.normalization = 1. / jnp.log(self.xmax / self.xmin)
+        else:
+            self.normalization = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) -
+                                                     self.xmin ** (1 + self.alpha))
 
     def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
         """
@@ -401,17 +407,12 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
 
     def log_prob(self, x: dict) -> Float:
         variable = x[self.naming[0]]
-        if self.alpha == -1:
-            normalization_constant = 1. / jnp.log(self.xmax / self.xmin)
-        else:
-            normalization_constant = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) -
-                                                         self.xmin ** (1 + self.alpha))
         log_in_range = jnp.where(
             (variable >= self.xmax) | (variable <= self.xmin),
             jnp.zeros_like(variable) - jnp.inf,
             jnp.zeros_like(variable),
         )
-        log_p = self.alpha * jnp.log(variable) + jnp.log(normalization_constant)
+        log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization)
         return log_p + log_in_range

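For reference, the precomputed constant is the one that makes the truncated power law p(x) proportional to x^alpha integrate to one on [xmin, xmax]: for alpha != -1 the integral of x^alpha gives (1 + alpha) / (xmax^(1 + alpha) - xmin^(1 + alpha)), and in the alpha = -1 limit it becomes 1 / log(xmax / xmin). Computing this once in __init__ means log_prob, which a sampler may call many times, no longer rebuilds the constant on every evaluation. The sketch below is a minimal, self-contained check of those two branches against trapezoidal integration; it is not jimgw code, and the names powerlaw_norm and check_normalization are illustrative only.

import jax.numpy as jnp

def powerlaw_norm(alpha: float, xmin: float, xmax: float):
    # Same two branches as the commit: x^-1 integrates to log(xmax / xmin),
    # while for alpha != -1 the antiderivative is x^(1 + alpha) / (1 + alpha).
    if alpha == -1:
        return 1.0 / jnp.log(xmax / xmin)
    return (1 + alpha) / (xmax ** (1 + alpha) - xmin ** (1 + alpha))

def check_normalization(alpha: float, xmin: float = 0.5, xmax: float = 2.0):
    # Integrate the normalized pdf over [xmin, xmax] with the trapezoid rule;
    # the result should be close to 1 for any alpha.
    x = jnp.linspace(xmin, xmax, 100_001)
    pdf = powerlaw_norm(alpha, xmin, xmax) * x ** alpha
    dx = x[1] - x[0]
    return jnp.sum((pdf[:-1] + pdf[1:]) * 0.5 * dx)

for a in (-2.5, -1.0, 0.0, 2.0):
    print(a, check_normalization(a))  # each integral is ~1.0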
