Skip to content

Commit

Permalink
Merge pull request #141 from xuyuon/fix-run-manager-bugs
Browse files Browse the repository at this point in the history
Fix run manager bugs
  • Loading branch information
xuyuon authored Aug 30, 2024
2 parents 863f34c + 5d1458e commit ca9f906
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ def __init__(self, **kwargs):
raise ValueError("Injection mode requires injection parameters.")

local_prior = self.initialize_prior()
local_likelihood = self.initialize_likelihood(local_prior)
sample_transforms, likelihood_transforms = self.initialize_transforms()
local_likelihood = self.initialize_likelihood(
local_prior, sample_transforms, likelihood_transforms
)
self.jim = Jim(
local_likelihood,
local_prior,
Expand All @@ -124,7 +126,12 @@ def load_from_path(self, path: str) -> SingleEventRun:

### Initialization functions ###

def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLiklihood:
def initialize_likelihood(
self,
prior: prior.CombinePrior,
sample_transforms: transforms.Transform,
likelihood_transforms: transforms.Transform,
) -> SingleEventLiklihood:
"""
Since prior contains information about types, naming and ranges of parameters,
some of the likelihood class require the prior to be initialized, such as the
Expand Down Expand Up @@ -176,6 +183,8 @@ def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLikliho
detectors,
waveform,
prior=prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
**self.run.likelihood_parameters,
**self.run.data_parameters,
)
Expand Down Expand Up @@ -221,6 +230,7 @@ def initialize_transforms(
transform_class = getattr(transforms, transform["name"])
except AttributeError:
raise ValueError(f"{transform['name']} not recognized.")
transform = transform.copy()
transform.pop("name")
sample_transforms.append(transform_class(**transform))
if self.run.likelihood_transforms:
Expand All @@ -239,6 +249,7 @@ def initialize_transforms(
transform_class = getattr(transforms, transform["name"])
except AttributeError:
raise ValueError(f"{transform['name']} not recognized.")
transform = transform.copy()
transform.pop("name")
likelihood_transforms.append(transform_class(**transform))
return sample_transforms, likelihood_transforms
Expand Down Expand Up @@ -460,9 +471,13 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
def save_summary(self, path: str = "", **kwargs):
if path == "":
path = self.run.path + "run_manager_summary.txt"
orig_stdout = sys.stdout
sys.stdout = open(path, "wt")
self.jim.print_summary()
for detector, SNR in zip(self.detectors, self.SNRs):
print("SNR of detector " + detector + " is " + str(SNR))
networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5)
print("network SNR is", networkSNR)
if self.run.injection:
for detector, SNR in zip(self.detectors, self.SNRs):
print("SNR of detector " + detector + " is " + str(SNR))
networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5)
print("network SNR is", networkSNR)
sys.stdout.close()
sys.stdout = orig_stdout

0 comments on commit ca9f906

Please sign in to comment.