Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing user to keep parameters after likelihood transform #143

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
178 commits
Select commit Hold shift + click to select a range
77e0247
Update Single_event_runManager.py
thomasckng Jul 23, 2024
86da005
Merge pull request #105 from kazewong/main
kazewong Jul 23, 2024
540b4db
Merge branch 'kazewong:jim-dev' into jim-dev
thomasckng Jul 23, 2024
8e6d789
Merge pull request #102 from thomasckng/jim-dev
kazewong Jul 23, 2024
75c74a3
scaffolding jim to handle naming.
kazewong Jul 24, 2024
c749dbd
Rename variables in jim
kazewong Jul 24, 2024
f45b41e
Starting tracking names in jim
kazewong Jul 24, 2024
dd6e0de
Add Transform class
kazewong Jul 24, 2024
4f3be70
Updated runManager.py
xuyuon Jul 24, 2024
f11f8c8
Add LogitToUniform Transform
kazewong Jul 24, 2024
c1192f2
Add logit distribution
kazewong Jul 24, 2024
b3ebc5e
Add propagate)name method in transform
kazewong Jul 24, 2024
95ef89b
Rename composite to Combine for combining priors
kazewong Jul 24, 2024
e6e45ef
scaffold prior test
kazewong Jul 24, 2024
d503a37
Add Sequential Transform
kazewong Jul 24, 2024
61b3b9e
Separate logit and scaling.
kazewong Jul 24, 2024
b47e6ae
Univeriate Transform seems working
kazewong Jul 24, 2024
e1cc408
Uniform prior now working.
kazewong Jul 24, 2024
d68cb36
Added inverse transform and uniform now perform correct
kazewong Jul 24, 2024
c16eef5
Add comments to current sequential transform prior class
kazewong Jul 24, 2024
8ab92a4
Removing inverse.
kazewong Jul 25, 2024
d8b2d1f
Add transformation function
kazewong Jul 25, 2024
ca1d6b6
Combine should be working now
kazewong Jul 25, 2024
2f3f412
Sine is an illegal transform since its Jacobian could be negative
kazewong Jul 25, 2024
4978cb1
Modify Uniform and add UniformSphere
thomasckng Jul 25, 2024
c1115bd
Add Sine and Cosine Prior
thomasckng Jul 26, 2024
01a6c1e
Revert "Modify Uniform and add UniformSphere"
thomasckng Jul 26, 2024
7ee4133
Add standard normal distribution
thomasckng Jul 26, 2024
6a35792
Add periodic uniform prior
thomasckng Jul 26, 2024
1bcf32c
Reformat
thomasckng Jul 26, 2024
17450be
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 26, 2024
5d98aeb
Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-p…
thomasckng Jul 26, 2024
87ee212
Minor text change
thomasckng Jul 26, 2024
cc1448e
Remove PeriodicUniform
thomasckng Jul 26, 2024
6070f13
Use self.sample_base
thomasckng Jul 26, 2024
8a301f2
Update prior.py
thomasckng Jul 26, 2024
d761f6b
Update prior.py to include powerLaw
xuyuon Jul 26, 2024
b9a7255
Update transforms.py
xuyuon Jul 26, 2024
2f6e12a
format and updating typing hint
kazewong Jul 25, 2024
194c565
Revert "Update transforms.py"
thomasckng Jul 26, 2024
5807f3c
Revert "Update prior.py to include powerLaw"
thomasckng Jul 26, 2024
98625c3
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 26, 2024
8256d0e
Comment out old prior
thomasckng Jul 26, 2024
b42b0f4
Reformat
thomasckng Jul 26, 2024
ba7dabb
Merge pull request #111 from thomasckng/sphere_prior
thomasckng Jul 26, 2024
a9629b2
Update Prior naming
kazewong Jul 26, 2024
fbba5af
Standize Transform naming
kazewong Jul 26, 2024
4a5d932
Fixing naming problem and add base_distribution tracer
kazewong Jul 26, 2024
3497b6d
Merge pull request #2 from kazewong/98-moving-naming-tracking-into-ji…
thomasckng Jul 26, 2024
27dc870
Updated powerlaw
xuyuon Jul 26, 2024
f7b883d
Updated powerlaw
xuyuon Jul 26, 2024
1665b1c
Updated powerlaw
xuyuon Jul 26, 2024
f51847d
Updated powerlaw
xuyuon Jul 26, 2024
f7876e6
Updated prior.py to include UniformComponenChirpMassPrior
xuyuon Jul 26, 2024
0b2c7cf
Fix priors
thomasckng Jul 27, 2024
d917f96
Add prior tests
thomasckng Jul 27, 2024
9390495
Merge branch 'prior-test' into 98-moving-naming-tracking-into-jim-cla…
thomasckng Jul 27, 2024
142fe54
Set constraint on powerlaw prior input
xuyuon Jul 29, 2024
fa6a6e2
Added test_power_law
xuyuon Jul 29, 2024
a8c0673
Updated ParetoTransform to avoid divide by zero
xuyuon Jul 29, 2024
6ce5b36
Revert update on ParetoTransform
xuyuon Jul 29, 2024
e9511b9
Updated test_prior.py
xuyuon Jul 29, 2024
54082ea
Updated test_prior.py
xuyuon Jul 29, 2024
9ad64a1
Updated test_prior.py
xuyuon Jul 29, 2024
ce437c4
Updated test_prior.py
xuyuon Jul 29, 2024
18d5184
Updated test_prior.py
xuyuon Jul 29, 2024
7c05d14
Updated test_prior.py
xuyuon Jul 29, 2024
f647bf0
Remove unnecessary test
thomasckng Jul 29, 2024
6171ce5
Reformat
thomasckng Jul 29, 2024
ab7447b
Change naming
thomasckng Jul 29, 2024
1b67056
Removed bilby
xuyuon Jul 29, 2024
b2dd606
Merge pull request #113 from xuyuon/98-moving-naming-tracking-into-ji…
kazewong Jul 29, 2024
9d59c4d
Move unit test to unit directory
kazewong Jul 29, 2024
b3110df
Update priors naming issue
kazewong Jul 29, 2024
024c5c3
base parameter names seem working
kazewong Jul 29, 2024
ac43d54
fix cosine naming error
kazewong Jul 29, 2024
151c37d
prior test now all pass
kazewong Jul 29, 2024
576ce7c
Fix likelihood sampling issue due to parameter name transformation
kazewong Jul 29, 2024
7625dbe
Instead of defining univariate and multivariate,
kazewong Jul 30, 2024
9ac722f
Adding some transform, should be working
kazewong Jul 30, 2024
935b314
Refactor into NtoN and NtoM transform
kazewong Jul 30, 2024
685670a
Fix bugs in ArcSine
kazewong Jul 30, 2024
39126f5
Add inverse mass transform
thomasckng Jul 30, 2024
46bd044
Update sequential prior class and test_GW15014.py.
kazewong Jul 31, 2024
a0161ee
Merge pull request #5 from kazewong/98-moving-naming-tracking-into-ji…
thomasckng Jul 31, 2024
8e1441b
correct sign errors and minior bugs
kazewong Jul 31, 2024
dfdfffa
Add mass transform
thomasckng Jul 31, 2024
dd03bf6
Fixed powerLaw
xuyuon Jul 31, 2024
735415e
Fixed powerLaw
xuyuon Jul 31, 2024
cec7447
Fixed powerLaw
xuyuon Jul 31, 2024
74e553d
Fixed powerLaw
xuyuon Jul 31, 2024
9d87e58
Add simplex transform
thomasckng Jul 31, 2024
5f33346
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
52fa0eb
Change transform to take dictionary as input for transform_func
kazewong Jul 31, 2024
2aef04b
Updated test_prior.py
xuyuon Jul 31, 2024
1e389bb
Add UniformComponentMassPrior
thomasckng Jul 31, 2024
0252392
prior system should work with dictionary function now
kazewong Jul 31, 2024
cb75bb9
update transform
kazewong Jul 31, 2024
fd32f20
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
561e628
Move subclassing structure
kazewong Jul 31, 2024
1aaf5ea
Remove transformation
thomasckng Jul 31, 2024
5bfdda1
Remove prior
thomasckng Jul 31, 2024
77a6ad1
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
2401df8
Add mass transform
thomasckng Jul 31, 2024
f932c77
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
ff35a82
Add Bound transforming transform
kazewong Jul 31, 2024
47af9cf
no nans now, but the code can be consolidate
kazewong Aug 1, 2024
1c45acb
Merge branch 'kazewong:98-moving-naming-tracking-into-jim-class-from-…
xuyuon Aug 1, 2024
b5e06a6
Solve conflict
thomasckng Aug 1, 2024
b7d08d3
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Aug 1, 2024
255fad3
Fixed powerLaw
xuyuon Aug 1, 2024
fcdc120
Fixed powerLaw
xuyuon Aug 1, 2024
06eb3ad
Reformatted
xuyuon Aug 1, 2024
8e5b326
Add sky position transform
thomasckng Aug 1, 2024
10d51b2
Modify sky position transform
thomasckng Aug 1, 2024
8368d00
Change util func name
thomasckng Aug 1, 2024
b4f6052
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Aug 1, 2024
1cb0d11
Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-p…
thomasckng Aug 1, 2024
78e93c5
Merge
thomasckng Aug 1, 2024
080bd8b
Modify integration test
thomasckng Aug 1, 2024
4058d32
Reformat
thomasckng Aug 1, 2024
cdf771d
Add typecheck
thomasckng Aug 1, 2024
02c5650
minor typo
thomasckng Aug 1, 2024
ce7ac34
Rename sampler
thomasckng Aug 1, 2024
7d44aa4
Fix test
thomasckng Aug 1, 2024
7826122
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
xuyuon Aug 1, 2024
aab31ff
Merge pull request #1 from kazewong/98-moving-naming-tracking-into-ji…
xuyuon Aug 1, 2024
45be9ae
Revert "Fixed PowerLaw prior"
xuyuon Aug 1, 2024
0acb87c
Merge pull request #120 from kazewong/revert-115-98-moving-naming-tra…
xuyuon Aug 1, 2024
e9288c8
Fix BoundToUnbound transform
thomasckng Aug 1, 2024
114178f
Merge branch 'kazewong:98-moving-naming-tracking-into-jim-class-from-…
xuyuon Aug 1, 2024
d5a34f2
Updated prior.py and transforms.py
xuyuon Aug 1, 2024
34f3d26
Updated test_prior.py
xuyuon Aug 1, 2024
fe500a7
Use ifos list
thomasckng Aug 1, 2024
0d28520
Fix jim summary and get_samples
thomasckng Aug 1, 2024
f7e3fe8
Fix jim output functions
thomasckng Aug 1, 2024
68bef54
Modify Transform
thomasckng Aug 1, 2024
87593e1
Fix jim output
thomasckng Aug 1, 2024
730fe31
Add comment
thomasckng Aug 1, 2024
ed727af
Updated test_prior.py
xuyuon Aug 2, 2024
9d191da
Updated test_prior.py
xuyuon Aug 2, 2024
957335e
Merge pull request #121 from xuyuon/98-moving-naming-tracking-into-ji…
kazewong Aug 2, 2024
5a5ff2f
Add sky position transform
thomasckng Aug 2, 2024
277e893
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
kazewong Aug 2, 2024
bd915ef
Add powerlaw transform back
kazewong Aug 2, 2024
ba86a65
Move single_event prior and transform
thomasckng Aug 2, 2024
3e1ea71
Tidy up test
thomasckng Aug 2, 2024
6ad882a
Add utils.py
thomasckng Aug 2, 2024
ede2b99
Move log_i0
thomasckng Aug 2, 2024
e764696
Fixing check
thomasckng Aug 2, 2024
9d9c833
Merge pull request #118 from thomasckng/transform
kazewong Aug 2, 2024
093dd09
Updated jim.py
xuyuon Aug 2, 2024
ceb0b7f
Added spin transform
xuyuon Aug 2, 2024
b972aea
Updated transform.py
xuyuon Aug 2, 2024
d4a4386
Updated test_GW150914_pv2.py
xuyuon Aug 2, 2024
0d6a5e5
Updated transform.py
xuyuon Aug 2, 2024
91f2f89
Updated test_GW150914_pv2.py
xuyuon Aug 2, 2024
6fc8887
Updated utils.py
xuyuon Aug 2, 2024
49d604d
Fix mass transform test
thomasckng Aug 2, 2024
3b7f9e8
Reformated
xuyuon Aug 2, 2024
ac0b1f5
Use PowerLaw for distance
thomasckng Aug 2, 2024
e0a61fa
Reformated
xuyuon Aug 2, 2024
6374772
Rename test_GW150914_pv2.py to test_GW150914_PV2.py
xuyuon Aug 2, 2024
fa134f9
Add heterodyne test
thomasckng Aug 2, 2024
2c9c6a6
Fix typecheck
thomasckng Aug 2, 2024
15e4074
Shorten test runtime
thomasckng Aug 2, 2024
87440db
Merge pull request #122 from xuyuon/98-moving-naming-tracking-into-ji…
kazewong Aug 2, 2024
e1800da
Merge pull request #123 from thomasckng/transform
kazewong Aug 2, 2024
174d43a
Fix jim output functions
thomasckng Aug 3, 2024
14121d7
Delete test_GW150914_PV2.py
thomasckng Aug 3, 2024
0a9f58f
Create test_GW150914_Pv2.py
thomasckng Aug 3, 2024
1b79a3a
Merge pull request #127 from thomasckng/transform
kazewong Aug 4, 2024
e803823
Merge pull request #6 from kazewong/98-moving-naming-tracking-into-ji…
xuyuon Aug 21, 2024
c6605f6
Update transforms.py
xuyuon Aug 21, 2024
59dd222
Update runManager.py
xuyuon Aug 21, 2024
702ee20
Update runManager.py
xuyuon Aug 21, 2024
7910785
Merge pull request #137 from xuyuon/fix-jacobian
kazewong Aug 22, 2024
864f13f
Allowing user to keep parameters after the likelihood transform
tsunhopang Aug 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@

name: Python package

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
on: [push, pull_request]

jobs:
build:
Expand Down
Empty file added docs/tutorials/naming_system.md
Empty file.
1 change: 0 additions & 1 deletion example/Single_event_runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

run = SingleEventRun(
seed=0,
path="test_data/GW150914/",
detectors=["H1", "L1"],
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
Expand Down
137 changes: 107 additions & 30 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,54 @@

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior
from jimgw.transforms import BijectiveTransform, NtoMTransform


class Jim(object):
"""
Master class for interfacing with flowMC

"""

def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
self.Likelihood = likelihood
self.Prior = prior
likelihood: LikelihoodBase
prior: Prior

# Name of parameters to sample from
sample_transforms: list[BijectiveTransform]
likelihood_transforms: list[NtoMTransform]
parameter_names: list[str]
sampler: Sampler
parameters_to_keep: list[str]

def __init__(
self,
likelihood: LikelihoodBase,
prior: Prior,
sample_transforms: list[BijectiveTransform] = [],
likelihood_transforms: list[NtoMTransform] = [],
parameters_to_keep: list[str] = [],
**kwargs,
):
self.likelihood = likelihood
self.prior = prior

self.sample_transforms = sample_transforms
self.likelihood_transforms = likelihood_transforms
self.parameter_names = prior.parameter_names
self.parameters_to_keep = parameters_to_keep

if len(sample_transforms) == 0:
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)

seed = kwargs.get("seed", 0)

Expand All @@ -33,30 +70,65 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

rng_key, subkey = jax.random.split(rng_key)
model = MaskedCouplingRQSpline(
self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey
self.prior.n_dim, num_layers, hidden_size, num_bins, subkey
)

self.Sampler = Sampler(
self.Prior.n_dim,
self.sampler = Sampler(
self.prior.n_dim,
rng_key,
None, # type: ignore
local_sampler,
model,
**kwargs,
)

def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
"""
Turn an array into a dictionary

Parameters
----------
x : Array
An array of parameters. Shape (n_dim,).
"""

return dict(zip(self.parameter_names, x))

def posterior(self, params: Float[Array, " n_dim"], data: dict):
prior_params = self.Prior.add_name(params.T)
prior = self.Prior.log_prob(prior_params)
return (
self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior
named_params = self.add_name(params)
transform_jacobian = 0.0
for transform in reversed(self.sample_transforms):
named_params, jacobian = transform.inverse(named_params)
transform_jacobian += jacobian
prior = self.prior.log_prob(named_params) + transform_jacobian

# make a copy of the named_params
named_params_copy = named_params.copy()
# do the likelihood transform
for transform in self.likelihood_transforms:
named_params = transform.forward(named_params)
# add back the parameters
jax.tree.map(
lambda key: named_params.update({key: named_params_copy[key]}),
self.parameters_to_keep,
)

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.Sampler.sample(initial_guess, None) # type: ignore
return self.likelihood.evaluate(named_params, data) + prior

def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
if initial_position.size == 0:
initial_guess = []
for _ in range(self.sampler.n_chains):
flag = True
while flag:
key = jax.random.split(key)[1]
guess = self.prior.sample(key, 1)
for transform in self.sample_transforms:
guess = transform.forward(guess)
guess = jnp.array([i for i in guess.values()]).T[0]
flag = not jnp.all(jnp.isfinite(guess))
initial_guess.append(guess)
initial_position = jnp.array(initial_guess)
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
self,
Expand All @@ -67,7 +139,7 @@ def maximize_likelihood(
):
key = jax.random.PRNGKey(seed)
set_nwalkers = set_nwalkers
initial_guess = self.Prior.sample(key, set_nwalkers)
initial_guess = self.prior.sample(key, set_nwalkers)

def negative_posterior(x: Float[Array, " n_dim"]):
return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet
Expand All @@ -78,7 +150,7 @@ def negative_posterior(x: Float[Array, " n_dim"]):
print("Done compiling")

print("Starting the optimizer")
optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True)
optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True)
_ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops)
best_fit = optimizer.get_result()[0]
return best_fit
Expand All @@ -89,22 +161,24 @@ def print_summary(self, transform: bool = True):

"""

train_summary = self.Sampler.get_sampler_state(training=True)
production_summary = self.Sampler.get_sampler_state(training=False)
train_summary = self.sampler.get_sampler_state(training=True)
production_summary = self.sampler.get_sampler_state(training=False)

training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T
training_chain = self.Prior.add_name(training_chain)
training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
training_chain = self.add_name(training_chain)
if transform:
training_chain = self.Prior.transform(training_chain)
for sample_transform in reversed(self.sample_transforms):
training_chain = sample_transform.backward(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T
production_chain = self.Prior.add_name(production_chain)
production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
production_chain = self.add_name(production_chain)
if transform:
production_chain = self.Prior.transform(production_chain)
for sample_transform in reversed(self.sample_transforms):
production_chain = sample_transform.backward(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]
Expand Down Expand Up @@ -156,11 +230,14 @@ def get_samples(self, training: bool = False) -> dict:

"""
if training:
chains = self.Sampler.get_sampler_state(training=True)["chains"]
chains = self.sampler.get_sampler_state(training=True)["chains"]
else:
chains = self.Sampler.get_sampler_state(training=False)["chains"]
chains = self.sampler.get_sampler_state(training=False)["chains"]

chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1)))
chains = chains.transpose(2, 0, 1)
chains = self.add_name(chains)
for sample_transform in reversed(self.sample_transforms):
chains = sample_transform.backward(chains)
return chains

def plot(self):
Expand Down
Loading
Loading