From 4f3be7043e5df159336b5427f80bf50b2e657828 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:29:05 -0400 Subject: [PATCH 1/7] Updated runManager.py --- src/jimgw/single_event/runManager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 3f65166d..aa8d0dc7 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,7 +71,9 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + injection_parameters: dict[str, float] = field( + default_factory=lambda: {} + ) injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -123,6 +125,9 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError + if self.run.injection and not self.run.injection_parameters: + raise ValueError("Injection mode requires injection parameters.") + local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) From c6605f68996834e5e914a3be229801ea9389978b Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 02:34:35 +0800 Subject: [PATCH 2/7] Update transforms.py --- src/jimgw/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..ac56a1e1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -89,7 +89,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -126,7 +126,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], From 59dd222e8cb1f67dad80e7d4119f7c505bea11d8 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 02:45:23 +0800 Subject: [PATCH 3/7] Update runManager.py --- src/jimgw/single_event/runManager.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index aa8d0dc7..e23c5396 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,9 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] = field( - default_factory=lambda: {} - ) + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -125,9 +123,6 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError - if self.run.injection and not self.run.injection_parameters: - raise ValueError("Injection mode requires injection parameters.") - local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) From 702ee20a6104b8a9561f2ba115809472985c49a0 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 02:45:58 +0800 Subject: [PATCH 4/7] Update runManager.py --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index e23c5396..3f65166d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,7 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} From b1133d36f4e6ff20667b30be9d7a91d58419d562 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:06 +0800 Subject: [PATCH 5/7] Update runManager.py --- src/jimgw/single_event/runManager.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index aa8d0dc7..0a4b502d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,9 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] = field( - default_factory=lambda: {} - ) + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -125,9 +123,6 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError - if self.run.injection and not self.run.injection_parameters: - raise ValueError("Injection mode requires injection parameters.") - local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) From 2a9d696a20d28fdd1693698a1840fef9ab276578 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:49 +0800 Subject: [PATCH 6/7] Update runManager.py --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 0a4b502d..3f65166d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,7 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} From 2adcabf0bbdbee0ee42b163c8ba51a45a4f1b9a6 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:11:25 +0800 Subject: [PATCH 7/7] Updated test_prior.py --- test/unit/test_prior.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 852ded16..5fbcf3c3 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -43,11 +43,8 @@ def test_sine(self): log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support - x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) - y = jax.vmap(p.base_prior.base_prior.transform)(x) - y = jax.vmap(p.base_prior.transform)(y) - y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0)) + samples = samples['x'] + assert jnp.allclose(log_prob, jnp.log(jnp.sin(samples)/2.0)) def test_cosine(self): p = CosinePrior(["x"]) @@ -57,11 +54,8 @@ def test_cosine(self): # Check that the log_prob is finite log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) - # Check that the log_prob is correct in the support - x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) - y = jax.vmap(p.base_prior.transform)(x) - y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0)) + samples = samples['x'] + assert jnp.allclose(log_prob, jnp.log(jnp.cos(samples)/2.0)) def test_uniform_sphere(self): p = UniformSpherePrior(["x"])