Skip to content

Commit

Permalink
Adding Alignedspin prior
Browse files Browse the repository at this point in the history
  • Loading branch information
tsunhopang committed Dec 4, 2023
1 parent ee7bc8a commit 43b6525
Showing 1 changed file with 102 additions and 0 deletions.
102 changes: 102 additions & 0 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,108 @@ def log_prob(self, x: dict) -> Float:
return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]]))


class Alignedspin(Prior):

"""
Prior distribution for the aligned (z) component of the spin.
This assume the prior distribution on the spin magnitude to be uniform in [0, amax]
with its orientation uniform on a sphere
p(chi) = -log(|chi| / amax) / 2 / amax
This is useful when comparing results between an aligned-spin run and
a precessing spin run.
See (A7) of https://arxiv.org/abs/1805.10457.
"""

amax: float = 0.99
chi_axis: Array = jnp.linspace(0, 1, num=1000)
cdf_vals: Array = jnp.linspace(0, 1, num=1000)

def __init__(
self,
amax: float,
naming: list[str],
transforms: dict[tuple[str, Callable]] = {},
):
super().__init__(naming, transforms)
assert isinstance(amax, float), "xmin must be a float"
assert self.n_dim == 1, "Alignedspin needs to be 1D distributions"
self.amax = amax

# build the interpolation table for the ppf of the one-sided distribution
chi_axis = jnp.linspace(1e-31, self.amax, num=1000)
cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.) / self.amax
self.chi_axis = chi_axis
self.cdf_vals = cdf_vals

def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
"""
Sample from the Alignedspin distribution.
for chi > 0;
p(chi) = -log(chi / amax) / amax # halved normalization constant
cdf(chi) = -chi * (log(chi / amax) - 1) / amax
Since there is a pole at chi=0, we will sample with the following steps
1. Map the samples with quantile > 0.5 to positive chi and negative otherwise
2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q)
2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5)
3. Map the quantile to chi via the ppf by checking against the table
built during the initialization
4. add back the sign
Parameters
----------
rng_key : jax.random.PRNGKey
A random key to use for sampling.
n_samples : int
The number of samples to draw.
Returns
-------
samples : dict
Samples from the distribution. The keys are the names of the parameters.
"""
q_samples = jax.random.uniform(
rng_key, (n_samples,), minval=0., maxval=1.
)
# 1. calculate the sign of chi from the q_samples
sign_samples = jnp.where(
q_samples >= 0.5,
jnp.zeros_like(q_samples) + 1.,
jnp.zeros_like(q_samples) - 1.,
)
# 2. remap q_samples
q_samples = jnp.where(
q_samples >=0.5,
2 * (q_samples - 0.5),
2 * (0.5 - q_samples),
)
# 3. map the quantile to chi via interpolation
samples = jnp.interp(
q_samples,
self.cdf_vals,
self.chi_axis,
)
# 4. add back the sign
samples *= sign_samples

return self.add_name(samples[None])

def log_prob(self, x: dict) -> Float:
variable = x[self.naming[0]]
log_p = jnp.where(
(variable >= self.amax) | (variable <= -self.amax),
jnp.zeros_like(variable) - jnp.inf,
jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2. / self.amax),
)
return log_p


class Powerlaw(Prior):

"""
Expand Down

0 comments on commit 43b6525

Please sign in to comment.