Skip to content

Commit

Permalink
Merge branch '41-add-more-prior-classes-and-add-composite-prior-example'
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Nov 28, 2023
2 parents fdc89f1 + 84abc7a commit 1e16222
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 23 deletions.
6 changes: 3 additions & 3 deletions example/GW150914_PV2.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@
jim = Jim(
likelihood,
prior,
n_loop_training=100,
n_loop_training=200,
n_loop_production=10,
n_local_steps=300,
n_global_steps=300,
n_chains=500,
n_epochs=300,
learning_rate=0.001,
max_samples = 60000,
max_samples = 10000,
momentum=0.9,
batch_size=60000,
batch_size=10000,
use_global=True,
keep_quantile=0.,
train_thinning=1,
Expand Down
111 changes: 91 additions & 20 deletions example/GW150914_PV2_newglobal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
jax.config.update("jax_enable_x64", True)

###########################################
########## First we grab data #############
########## This script is experimental ####
###########################################

total_time_start = time.time()
Expand All @@ -30,50 +30,121 @@

waveform = RippleIMRPhenomPv2(f_ref=20)

Mc_prior = Unconstrained_Uniform(10., 80., naming=["M_c"])
q_prior = Unconstrained_Uniform(0.125, 1., naming=["q"], transforms={"q": ("eta", lambda params: params['q']/(1+params['q'])**2)})
Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"])
q_prior = Unconstrained_Uniform(
0.125,
1.0,
naming=["q"],
transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)},
)
s1_prior = Sphere("s1")
s2_prior = Sphere("s2")
dL_prior = Unconstrained_Uniform(0., 2000., naming=["d_L"])
dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"])
t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"])
phase_c_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["phase_c"])
cos_iota_prior = Unconstrained_Uniform(-1., 1., naming=["cos_iota"], transforms={"cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi))})
psi_prior = Unconstrained_Uniform(0., jnp.pi, naming=["psi"])
ra_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["ra"])
sin_dec_prior = Unconstrained_Uniform(-1., 1., naming=["sin_dec"], transforms={"sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))})
phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"])
cos_iota_prior = Unconstrained_Uniform(
-1.0,
1.0,
naming=["cos_iota"],
transforms={
"cos_iota": (
"iota",
lambda params: jnp.arccos(
jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi
),
)
},
)
psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"])
ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"])
sin_dec_prior = Unconstrained_Uniform(
-1.0,
1.0,
naming=["sin_dec"],
transforms={
"sin_dec": (
"dec",
lambda params: jnp.arcsin(
jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi
),
)
},
)

prior = Composite([Mc_prior, q_prior, s1_prior, s2_prior, dL_prior, t_c_prior, phase_c_prior, cos_iota_prior, psi_prior, ra_prior, sin_dec_prior])
prior = Composite(
[
Mc_prior,
q_prior,
s1_prior,
s2_prior,
dL_prior,
t_c_prior,
phase_c_prior,
cos_iota_prior,
psi_prior,
ra_prior,
sin_dec_prior,
]
)

likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)
optimization_bounds = jnp.array(
[
[-10.0, 10.0],
[-10.0, 10.0],
[0.0, 2.0 * jnp.pi],
[-1.0, 1.0],
[0.01, 1.0],
[0.0, 2.0 * jnp.pi],
[-1.0, 1.0],
[0.01, 1.0],
[-10.0, 10.0],
[-30.0, 30.0],
[-10.0, 10.0],
[-10.0, 10.0],
[-10.0, 10.0],
[-10.0, 10.0],
[-10.0, 10.0],
]
)

likelihood = TransientLikelihoodFD(
[H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2
)


mass_matrix = jnp.eye(prior.n_dim)
# mass_matrix = mass_matrix.at[1, 1].set(1e-3)
# mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[9, 9].set(1e-3)
mass_matrix = mass_matrix * 3e-3
local_sampler_arg = {"step_size": mass_matrix}


jim = Jim(
likelihood,
prior,
n_loop_training=50,
n_loop_training=20,
n_loop_production=10,
n_local_steps=300,
n_global_steps=300,
n_chains=500,
n_epochs=300,
learning_rate=0.001,
max_samples = 60000,
max_samples=60000,
momentum=0.9,
batch_size=30000,
use_global=True,
keep_quantile=0.,
keep_quantile=0.0,
train_thinning=1,
output_thinning=30,
local_sampler_arg=local_sampler_arg,
num_layers = 6,
hidden_size = [32,32],
num_bins = 8
num_layers=6,
hidden_size=[32, 32],
num_bins=8,
flowHMC_params={
"step_size": 1e-2,
"n_leapfrog": 3,
"condition_matrix": jnp.linalg.inv(mass_matrix),
},
)

# jim.maximize_likelihood([prior.xmin, prior.xmax])
Expand Down

0 comments on commit 1e16222

Please sign in to comment.