Skip to content

Commit

Permalink
Update hierarchical.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WuShichao authored Aug 31, 2023
1 parent 116e5bc commit cb5a9c4
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions pycbc/inference/models/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from .base import BaseModel
from .relbin import RelativeTimeDom
from .relbin_cpu import snr_predictor_dom
import tqdm

#
# =============================================================================
Expand Down Expand Up @@ -627,12 +626,11 @@ class MultibandRelativeTimeDom(HierarchicalModel):
def __init__(self, variable_params, submodels, **kwargs):
super().__init__(variable_params, submodels, **kwargs)

# We assume the ground-based submodel as the primary model.
# assume the ground-based submodel as the primary model
self.primary_model = self.submodels[kwargs['primary_lbl'][0]]
self.other_models = self.submodels.copy()
self.other_models.pop(kwargs['primary_lbl'][0])
self.other_models = list(self.other_models.values())
self.other_models_labels = kwargs['others_lbls']

def write_metadata(self, fp, group=None):
"""Adds metadata to the output files
Expand Down Expand Up @@ -661,8 +659,8 @@ def write_metadata(self, fp, group=None):
except AttributeError:
pass

def _loglr(self):
r"""Computes the log likelihood ratio,
def total_loglr(self):
r"""Computes the total log likelihood ratio,
.. math::
Expand All @@ -681,47 +679,69 @@ def _loglr(self):

# note that for SOBHB signals, ground-based detectors dominant SNR
# and accuracy of (tc, ra, dec)
sh_primary, hh_primary = self.primary_model._loglr(just_sh_hh=True)
sh_primary, hh_primary = self.primary_model.get_sh_hh()
# print("sh_primary: ", sh_primary)
# print("len(sh_primary): ", len(sh_primary))

nums = self.primary_model.vsamples
margin_params = self.primary_model.marginalize_vector_params.copy()
margin_params.pop('logw_partial')
# margin_params_transformed = sefl.waveform_transforms()

# add likelihood contribution from space-borne detectors, we
# calculate sh/hh for each marginalized parameter
logging.info("Calculating sh/hh for space-borne detector(s)")
sh_others = numpy.full(nums, 0 + 0.0j)
hh_others = numpy.zeros(nums)

for label_i, other_model in enumerate(self.other_models):
logging.info("============= %s =============",
self.other_models_labels[label_i])
# we need static params for other models
current_params_other = other_model.current_params.copy()
# there are still some values in margin_params
# there are still params from margin_params, need investigation
for p in margin_params:
current_params_other.pop(p)
for i in tqdm.tqdm(range(nums)):
# for p in margin_params_transformed:
# current_params_other.pop(p)
print("current_params: ", current_params_other)
for i in range(nums):
parameters = current_params_other.copy()
parameters.update(
{key: value[i] for key, value in margin_params.items()})
{key: value[i] for key, value in
margin_params.items()})
print("parameters: ", parameters)
other_model.update(**parameters)
sh_others[i], hh_others[i] = other_model._loglr(just_sh_hh=True)

sh_total = sh_primary + sh_others
hh_total = hh_primary + hh_others
sh_others[i], hh_others[i] = other_model.get_sh_hh()

# # just a demo to show can't do transform itself
# for label_i, other_model in enumerate(self.other_models):
# current_params_primary = self.primary_model.current_params.copy()
# print("current_params: ", current_params_primary)
# for i in range(nums):
# parameters = current_params_primary
# other_model.update(**parameters)
# sh_others[i], hh_others[i] = other_model.get_sh_hh()

# sh_total = sh_primary + sh_others
# hh_total = hh_primary + hh_others
sh_total = sh_others
hh_total = hh_others
# calculate marginalize_vector_weights
self.primary_model.marginalize_vector_weights = \
- numpy.log(self.primary_model.vsamples)
loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)
return loglr

def _loglikelihood(self):
other_models_lognl = 0
others_lognl = 0
for lbl, model in self.submodels.items():
print("lbl: ", lbl)
print("self.param_map[lbl]: ", self.param_map[lbl])
# Update the parameters of each
model.update(**{p.subname: self.current_params[p.fullname]
for p in self.param_map[lbl]})
other_models_lognl += model.lognl
others_lognl += model.lognl

# calculate the combined loglikelihood
logl = self._loglr() + self.primary_model.lognl + other_models_lognl
logl = self.total_loglr() + self.primary_model.lognl + others_lognl

# store any extra stats from the submodels
for lbl, model in self.submodels.items():
Expand Down

0 comments on commit cb5a9c4

Please sign in to comment.