Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prior uniform in m1-m2 space with a bound in chirp mass and mass ratio #164

Closed
wants to merge 8 commits into from

Conversation

tsunhopang
Copy link
Collaborator

This PR is to have the prior defined as uniform in component mass, while the bound is defined by chirp mass and mass ratio.
Although using the bound-to-unbound in chirp mass and mass ratio space and the current initialization procedure would solve the sampling side problem, the evidence estimation will be off due to a constant shift in the log posterior. This PR is to fix this issue.

The following test script is used, with the scatter plot of the samples shown.

test_m1_m2

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior
from jimgw.single_event.likelihood import ZeroLikelihood
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import UniformInComponentMassSecondaryMassTransform
from jimgw.single_event.utils import Mc_m1_to_m2
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

m1_prior = UniformPrior(6.0, 53.0, parameter_names=['m_1'])
m2_q_prior = UniformPrior(0.0, 1.0, parameter_names=['m_2_quantile'])

prior = CombinePrior(
    [
        m1_prior,
        m2_q_prior,
    ]
)


sample_transforms = [
    # all the bound-to-unbound transform
    BoundToUnbound(
        name_mapping = (["m_1"], ["m_1_unbounded"]),
        original_lower_bound=m1_prior.xmin, original_upper_bound=m1_prior.xmax
    ),
    BoundToUnbound(
        name_mapping = (["m_2_quantile"], ["m_2_quantile_unbounded"]),
        original_lower_bound=m2_q_prior.xmin, original_upper_bound=m2_q_prior.xmax
    ),
]

likelihood_transforms = [
    UniformInComponentMassSecondaryMassTransform(
        q_min=0.125, q_max=1.0,
        M_c_min=5.0, M_c_max=15.0,
        m_1_min=m1_prior.xmin,
        m_1_max=m1_prior.xmax
    ),
]

likelihood = ZeroLikelihood()

mass_matrix = jnp.eye(len(prior.base_prior))
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1)

n_epochs = 2
n_loop_training = 1
learning_rate = 1e-4


jim = Jim(
    likelihood,
    prior,
    sample_transforms=sample_transforms,
    likelihood_transforms=likelihood_transforms,
    n_loop_training=n_loop_training,
    n_loop_production=1,
    n_local_steps=5,
    n_global_steps=100,
    n_chains=100,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    n_max_examples=30,
    n_flow_samples=100,
    momentum=0.9,
    batch_size=100,
    use_global=True,
    train_thinning=1,
    output_thinning=1,
    local_sampler_arg=local_sampler_arg,
    strategies=["default"],
)

print("Start sampling")
key = jax.random.PRNGKey(42)
jim.sample(key)
jim.print_summary()
samples = jim.get_samples()

for transform in likelihood_transforms:
    samples = jax.vmap(transform.forward)(samples)

import matplotlib
matplotlib.use("agg")
matplotlib.rcParams.update(
    {'font.size': 16,
     'text.usetex': True,
     'font.family': 'Times New Roman'}
)
import matplotlib.pyplot as plt
plt.figure(1)
plt.xlim([5, 53])
plt.ylim([1.8, 18])
# drawing mass ratio lines
import numpy as np
x = np.linspace(0., 100.)
plt.plot(x, 1 * x, color='k', linestyle='--')
plt.plot(x, 0.125 * x, color='k', linestyle='--')
plt.plot(x, Mc_m1_to_m2(5., x)[0].real, color='k', linestyle='--')
plt.plot(x, Mc_m1_to_m2(15., x)[0].real, color='k', linestyle='--')
plt.scatter(samples['m_1'], samples['m_2'])
plt.xlabel(r'$m_1 [M_\odot]$')
plt.ylabel(r'$m_2 [M_\odot]$')
plt.savefig('test_m1_m2.png', bbox_inches='tight')

@tsunhopang tsunhopang requested a review from kazewong October 15, 2024 17:04
@tsunhopang tsunhopang self-assigned this Oct 15, 2024
@@ -24,6 +26,131 @@
)


@jaxtyped(typechecker=typechecker)
class UniformInComponentMassSecondaryMassTransform(ConditionalBijectiveTransform):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a transform or a prior? Also, is this used in the testing script?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a transform

Copy link
Owner

@kazewong kazewong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two minor comments for this PR. Once they are addressed I am happy to merge this

@@ -38,6 +38,31 @@ def inner_product(
return 4.0 * jnp.real(trapezoid(integrand, dx=df))


def Mc_m1_to_m2(Mc: Float, m1: Float):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return typing hint is missing here

@tsunhopang
Copy link
Collaborator Author

Added the transform into the Pv2 testing script and updated the Mc_m1_to_m2 function

@tsunhopang tsunhopang closed this Oct 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants