From 47d52756cab62802c8cfc31a485c3705869ebc07 Mon Sep 17 00:00:00 2001 From: "Michael J. Williams" Date: Thu, 14 Dec 2023 15:52:41 +0000 Subject: [PATCH] Add support for nessai sampler in pycbc inference (#4567) * add basic support for nessai sampler * enable all options and resuming in nessai * fix prior bounds in nessai model * tweak resuming and samples in nessai interface * change outdir to avoid namespace conflicts * tweaks to nessai sampler class * fix nessai checkpointing and other minor tweaks * fix for reading in nessai result files * use callback for checkpointing in nessai * start addressing codeclimate issues * add nessai to auxiliary samplers * add additional comments for nessai * make simple sampler example 2d nessai does not support 1d likelihoods, so this change is neede to test nessai in the CI * fix call to rng.random * add nessai to samplers example and update plot * set minimum version for nessai * force cpu-only version of torch * add missing epsie jump proposal * add plot-marginal to samplers plot * fix whitespace * use lazy formatting in logging functions * move functions to common nested class * update for change common nested class * address more code climate issues --- companion.txt | 4 + examples/inference/samplers/epsie_stub.ini | 3 + examples/inference/samplers/nessai_stub.ini | 3 + examples/inference/samplers/run.sh | 9 +- examples/inference/samplers/simp.ini | 6 + pycbc/inference/io/__init__.py | 2 + pycbc/inference/io/dynesty.py | 76 ++-- pycbc/inference/io/nessai.py | 49 +++ pycbc/inference/sampler/__init__.py | 7 + pycbc/inference/sampler/nessai.py | 364 ++++++++++++++++++++ 10 files changed, 483 insertions(+), 40 deletions(-) create mode 100644 examples/inference/samplers/nessai_stub.ini create mode 100644 pycbc/inference/io/nessai.py create mode 100644 pycbc/inference/sampler/nessai.py diff --git a/companion.txt b/companion.txt index a101c774632..19ea6b0bb73 100644 --- a/companion.txt +++ b/companion.txt @@ -14,6 +14,10 @@ cpnest pymultinest ultranest https://github.com/willvousden/ptemcee/archive/master.tar.gz +# Force the cpu-only version of PyTorch +--extra-index-url https://download.pytorch.org/whl/cpu +torch +nessai>=0.11.0 # useful to look at PyCBC Live with htop setproctitle diff --git a/examples/inference/samplers/epsie_stub.ini b/examples/inference/samplers/epsie_stub.ini index 64c210a6045..7953b4ef266 100644 --- a/examples/inference/samplers/epsie_stub.ini +++ b/examples/inference/samplers/epsie_stub.ini @@ -14,3 +14,6 @@ ntemps = 4 [jump_proposal-x] name = normal + +[jump_proposal-y] +name = normal diff --git a/examples/inference/samplers/nessai_stub.ini b/examples/inference/samplers/nessai_stub.ini new file mode 100644 index 00000000000..cdde947b339 --- /dev/null +++ b/examples/inference/samplers/nessai_stub.ini @@ -0,0 +1,3 @@ +[sampler] +name = nessai +nlive = 200 diff --git a/examples/inference/samplers/run.sh b/examples/inference/samplers/run.sh index 5012f0b251b..91e41e7fbb7 100755 --- a/examples/inference/samplers/run.sh +++ b/examples/inference/samplers/run.sh @@ -1,5 +1,5 @@ #!/bin/sh -for f in cpnest_stub.ini emcee_stub.ini emcee_pt_stub.ini dynesty_stub.ini ultranest_stub.ini epsie_stub.ini; do +for f in cpnest_stub.ini emcee_stub.ini emcee_pt_stub.ini dynesty_stub.ini ultranest_stub.ini epsie_stub.ini nessai_stub.ini; do echo $f pycbc_inference \ --config-files `dirname $0`/simp.ini `dirname $0`/$f \ @@ -16,4 +16,9 @@ dynesty_stub.ini.hdf:dynesty \ ultranest_stub.ini.hdf:ultranest \ epsie_stub.ini.hdf:espie \ cpnest_stub.ini.hdf:cpnest \ ---output-file sample.png +nessai_stub.ini.hdf:nessai \ +--output-file sample.png \ +--plot-contours \ +--plot-marginal \ +--no-contour-labels \ +--no-marginal-titles diff --git a/examples/inference/samplers/simp.ini b/examples/inference/samplers/simp.ini index e82f9854ebe..88d883a89ea 100644 --- a/examples/inference/samplers/simp.ini +++ b/examples/inference/samplers/simp.ini @@ -3,8 +3,14 @@ name = test_normal [variable_params] x = +y = [prior-x] name = uniform min-x = -10 max-x = 10 + +[prior-y] +name = uniform +min-y = -10 +max-y = 10 diff --git a/pycbc/inference/io/__init__.py b/pycbc/inference/io/__init__.py index 93261f677be..4b3fd0ce909 100644 --- a/pycbc/inference/io/__init__.py +++ b/pycbc/inference/io/__init__.py @@ -37,6 +37,7 @@ from .multinest import MultinestFile from .dynesty import DynestyFile from .ultranest import UltranestFile +from .nessai import NessaiFile from .posterior import PosteriorFile from .txt import InferenceTXTFile @@ -49,6 +50,7 @@ DynestyFile.name: DynestyFile, PosteriorFile.name: PosteriorFile, UltranestFile.name: UltranestFile, + NessaiFile.name: NessaiFile, } try: diff --git a/pycbc/inference/io/dynesty.py b/pycbc/inference/io/dynesty.py index 1d77fa35608..5a79882f3d3 100644 --- a/pycbc/inference/io/dynesty.py +++ b/pycbc/inference/io/dynesty.py @@ -90,6 +90,44 @@ def extra_args_parser(parser=None, skip_args=None, **kwargs): "extracted instead.") return parser, actions + def write_pickled_data_into_checkpoint_file(self, state): + """Dump the sampler state into checkpoint file + """ + if 'sampler_info/saved_state' not in self: + self.create_group('sampler_info/saved_state') + dump_state(state, self, path='sampler_info/saved_state') + + def read_pickled_data_from_checkpoint_file(self): + """Load the sampler state (pickled) from checkpoint file + """ + return load_state(self, path='sampler_info/saved_state') + + def write_raw_samples(self, data, parameters=None): + """Write the nested samples to the file + """ + if 'samples' not in self: + self.create_group('samples') + write_samples_to_file(self, data, parameters=parameters, + group='samples') + def validate(self): + """Runs a validation test. + This checks that a samples group exist, and that pickeled data can + be loaded. + + Returns + ------- + bool : + Whether or not the file is valid as a checkpoint file. + """ + try: + if 'sampler_info/saved_state' in self: + load_state(self, path='sampler_info/saved_state') + checkpoint_valid = True + except KeyError: + checkpoint_valid = False + return checkpoint_valid + + class DynestyFile(CommonNestedMetadataIO, BaseNestedSamplerFile): """Class to handle file IO for the ``dynesty`` sampler.""" @@ -148,41 +186,3 @@ def read_raw_samples(self, fields, raw_samples=False, seed=0): return post else: return samples - - def write_pickled_data_into_checkpoint_file(self, state): - """Dump the sampler state into checkpoint file - """ - if 'sampler_info/saved_state' not in self: - self.create_group('sampler_info/saved_state') - dump_state(state, self, path='sampler_info/saved_state') - - def read_pickled_data_from_checkpoint_file(self): - """Load the sampler state (pickled) from checkpoint file - """ - return load_state(self, path='sampler_info/saved_state') - - def write_raw_samples(self, data, parameters=None): - """Write the nested samples to the file - """ - if 'samples' not in self: - self.create_group('samples') - write_samples_to_file(self, data, parameters=parameters, - group='samples') - - def validate(self): - """Runs a validation test. - This checks that a samples group exist, and that pickeled data can - be loaded. - - Returns - ------- - bool : - Whether or not the file is valid as a checkpoint file. - """ - try: - if 'sampler_info/saved_state' in self: - load_state(self, path='sampler_info/saved_state') - checkpoint_valid = True - except KeyError: - checkpoint_valid = False - return checkpoint_valid diff --git a/pycbc/inference/io/nessai.py b/pycbc/inference/io/nessai.py new file mode 100644 index 00000000000..86c1bcfba41 --- /dev/null +++ b/pycbc/inference/io/nessai.py @@ -0,0 +1,49 @@ +"""Provides IO for the nessai sampler""" +import numpy + +from .base_nested_sampler import BaseNestedSamplerFile + +from .posterior import read_raw_samples_from_file +from .dynesty import CommonNestedMetadataIO + + +class NessaiFile(CommonNestedMetadataIO, BaseNestedSamplerFile): + """Class to handle file IO for the ``nessai`` sampler.""" + + name = "nessai_file" + + def read_raw_samples(self, fields, raw_samples=False, seed=0): + """Reads samples from a nessai file and constructs a posterior. + + Using rejection sampling to resample the nested samples + + Parameters + ---------- + fields : list of str + The names of the parameters to load. Names must correspond to + dataset names in the file's ``samples`` group. + raw_samples : bool, optional + Return the raw (unweighted) samples instead of the estimated + posterior samples. Default is False. + + Returns + ------- + dict : + Dictionary of parameter fields -> samples. + """ + samples = read_raw_samples_from_file(self, fields) + logwt = read_raw_samples_from_file(self, ['logwt'])['logwt'] + loglikelihood = read_raw_samples_from_file( + self, ['loglikelihood'])['loglikelihood'] + if not raw_samples: + n_samples = len(logwt) + # Rejection sample + rng = numpy.random.default_rng(seed) + logwt -= logwt.max() + logu = numpy.log(rng.random(n_samples)) + keep = logwt > logu + post = {'loglikelihood': loglikelihood[keep]} + for param in fields: + post[param] = samples[param][keep] + return post + return samples diff --git a/pycbc/inference/sampler/__init__.py b/pycbc/inference/sampler/__init__.py index 1b83f52d6cf..41da16f39c6 100644 --- a/pycbc/inference/sampler/__init__.py +++ b/pycbc/inference/sampler/__init__.py @@ -66,6 +66,13 @@ except ImportError: pass +try: + from .nessai import NessaiSampler + samplers[NessaiSampler.name] = NessaiSampler +except ImportError: + pass + + def load_from_config(cp, model, **kwargs): """Loads a sampler from the given config file. diff --git a/pycbc/inference/sampler/nessai.py b/pycbc/inference/sampler/nessai.py new file mode 100644 index 00000000000..9da6c7461c7 --- /dev/null +++ b/pycbc/inference/sampler/nessai.py @@ -0,0 +1,364 @@ +""" +This modules provides class for using the nessai sampler package for parameter +estimation. + +Documentation for nessai: https://nessai.readthedocs.io/en/latest/ +""" +import ast +import logging +import os + +import nessai.flowsampler +import nessai.model +import nessai.livepoint +import nessai.utils.multiprocessing +import nessai.utils.settings +import numpy +import numpy.lib.recfunctions as rfn + +from .base import BaseSampler, setup_output +from .base_mcmc import get_optional_arg_from_config +from ..io import NessaiFile, loadfile +from ...pool import choose_pool + + +class NessaiSampler(BaseSampler): + """Class to construct a FlowSampler from the nessai package.""" + + name = "nessai" + _io = NessaiFile + + def __init__( + self, + model, + loglikelihood_function, + nlive=1000, + nprocesses=1, + use_mpi=False, + run_kwds=None, + extra_kwds=None, + ): + super().__init__(model) + + self.nlive = nlive + self.model_call = NessaiModel(self.model, loglikelihood_function) + + self.extra_kwds = extra_kwds if extra_kwds is not None else {} + self.run_kwds = run_kwds if run_kwds is not None else {} + + nessai.utils.multiprocessing.initialise_pool_variables(self.model_call) + self.pool = choose_pool(mpi=use_mpi, processes=nprocesses) + self.nprocesses = nprocesses + + self._sampler = None + self._nested_samples = None + self._posterior_samples = None + self._logz = None + self._dlogz = None + self.checkpoint_file = None + self.resume_data = None + + @property + def io(self): + return self._io + + @property + def model_stats(self): + pass + + @property + def samples(self): + """The raw nested samples including the corresponding weights""" + if self._sampler.ns.nested_samples: + ns = numpy.array(self._sampler.ns.nested_samples) + samples = nessai.livepoint.live_points_to_dict( + ns, + self.model.sampling_params, + ) + samples["logwt"] = self._sampler.ns.state.log_posterior_weights + samples["loglikelihood"] = ns["logL"] + samples["logprior"] = ns["logP"] + samples["it"] = ns["it"] + else: + samples = {} + return samples + + def run(self, **kwargs): + """Run the sampler""" + out_dir = os.path.join( + os.path.dirname(os.path.abspath(self.checkpoint_file)), + "outdir_nessai", + ) + default_kwds, default_run_kwds = self.get_default_kwds( + importance_nested_sampler=self.extra_kwds.get( + "importance_nested_sampler", False + ) + ) + + extra_kwds = self.extra_kwds.copy() + run_kwds = self.run_kwds.copy() + + if kwargs is not None: + logging.info("Updating keyword arguments with %s", kwargs) + extra_kwds.update( + {k: v for k, v in kwargs.items() if k in default_kwds} + ) + run_kwds.update( + {k: v for k, v in kwargs.items() if k in default_run_kwds} + ) + + if self._sampler is None: + logging.info("Initialising nessai FlowSampler") + self._sampler = nessai.flowsampler.FlowSampler( + self.model_call, + output=out_dir, + pool=self.pool, + n_pool=self.nprocesses, + close_pool=False, + signal_handling=False, + resume_data=self.resume_data, + checkpoint_callback=self.checkpoint_callback, + **extra_kwds, + ) + logging.info("Starting sampling with nessai") + self._sampler.run(**run_kwds) + + @staticmethod + def get_default_kwds(importance_nested_sampler=False): + """Return lists of all allowed keyword arguments for nessai. + + Returns + ------- + default_kwds : list + List of keyword arguments that can be passed to FlowSampler + run_kwds: list + List of keyword arguments that can be passed to FlowSampler.run + """ + return nessai.utils.settings.get_all_kwargs( + importance_nested_sampler=importance_nested_sampler, + split_kwargs=True, + ) + + @classmethod + def from_config( + cls, cp, model, output_file=None, nprocesses=1, use_mpi=False + ): + """ + Loads the sampler from the given config file. + """ + section = "sampler" + # check name + assert ( + cp.get(section, "name") == cls.name + ), "name in section [sampler] must match mine" + + if cp.has_option(section, "importance_nested_sampler"): + importance_nested_sampler = cp.get( + section, + "importance_nested_sampler", + ) + else: + importance_nested_sampler = False + + # Requires additional development work, see the model class below + if importance_nested_sampler is True: + raise NotImplementedError( + "Importance nested sampler is not currently supported" + ) + + default_kwds, default_run_kwds = cls.get_default_kwds( + importance_nested_sampler + ) + + # Keyword arguments the user cannot configure via the config + remove_kwds = [ + "pool", + "n_pool", + "close_pool", + "signal_handling", + "checkpoint_callback", + ] + + for kwd in remove_kwds: + default_kwds.pop(kwd, None) + default_run_kwds.pop(kwd, None) + + kwds = {} + run_kwds = {} + + # ast.literal_eval is used here since specifying a dictionary with all + # various types would be difficult. However, one may wish to revisit + # this in future, e.g. if evaluating code is a concern. + for d_out, d_defaults in zip( + [kwds, run_kwds], [default_kwds, default_run_kwds] + ): + for k in d_defaults.keys(): + if cp.has_option(section, k): + d_out[k] = ast.literal_eval(cp.get(section, k)) + + # Specified kwds + ignore_kwds = {"nlive", "name"} + invalid_kwds = ( + cp[section].keys() + - set().union(kwds.keys(), run_kwds.keys()) + - ignore_kwds + ) + + if invalid_kwds: + raise RuntimeError( + f"Config contains unknown options: {invalid_kwds}" + ) + logging.info("nessai keyword arguments: %s", kwds) + logging.info("nessai run keyword arguments: %s", run_kwds) + + loglikelihood_function = get_optional_arg_from_config( + cp, section, "loglikelihood-function" + ) + + obj = cls( + model, + loglikelihood_function=loglikelihood_function, + nprocesses=nprocesses, + use_mpi=use_mpi, + run_kwds=run_kwds, + extra_kwds=kwds, + ) + + # Do not need to check number of samples for a nested sampler + setup_output(obj, output_file, check_nsamples=False) + if not obj.new_checkpoint: + obj.resume_from_checkpoint() + return obj + + def set_initial_conditions( + self, + initial_distribution=None, + samples_file=None, + ): + """Sets up the starting point for the sampler. + + This is not used for nessai. + """ + + def checkpoint_callback(self, state): + """Callback for checkpointing. + + This will be called periodically by nessai. + """ + for fn in [self.checkpoint_file, self.backup_file]: + with self.io(fn, "a") as fp: + fp.write_pickled_data_into_checkpoint_file(state) + self.write_results(fn) + + def checkpoint(self): + """Checkpoint the sampler""" + self.checkpoint_callback(self._sampler.ns) + + def resume_from_checkpoint(self): + """Reads the resume data from the checkpoint file.""" + try: + with loadfile(self.checkpoint_file, "r") as fp: + self.resume_data = fp.read_pickled_data_from_checkpoint_file() + logging.info( + "Found valid checkpoint file: %s", self.checkpoint_file + ) + except Exception as e: + logging.info("Failed to load checkpoint file with error: %s", e) + + def finalize(self): + """Finalize sampling""" + logz = self._sampler.ns.log_evidence + dlogz = self._sampler.ns.log_evidence_error + logging.info("log Z, dlog Z: %s, %s", logz, dlogz) + self.checkpoint() + + def write_results(self, filename): + """Write the results to a given file. + + Writes the nested samples, log-evidence and log-evidence error. + """ + with self.io(filename, "a") as fp: + fp.write_raw_samples(self.samples) + fp.write_logevidence( + self._sampler.ns.log_evidence, + self._sampler.ns.log_evidence_error, + ) + + +class NessaiModel(nessai.model.Model): + """Wrapper for PyCBC Inference model class for use with nessai. + + Parameters + ---------- + model : inference.BaseModel instance + A model instance from PyCBC. + loglikelihood_function : str + Name of the log-likelihood method to call. + """ + + def __init__(self, model, loglikelihood_function=None): + self.model = model + self.names = list(model.sampling_params) + + # Configure the log-likelihood function + if loglikelihood_function is None: + loglikelihood_function = "loglikelihood" + self.loglikelihood_function = loglikelihood_function + + # Configure the priors bounds + bounds = {} + for dist in model.prior_distribution.distributions: + bounds.update( + **{ + k: [v.min, v.max] + for k, v in dist.bounds.items() + if k in self.names + } + ) + self.bounds = bounds + # Prior and likelihood are not vectorised + self.vectorised_likelihood = False + self.vectorised_prior = False + # Use the pool for computing the prior + self.parallelise_prior = True + + def to_dict(self, x): + """Convert a nessai live point array to a dictionary""" + return {n: x[n].item() for n in self.names} + + def to_live_points(self, x): + """Convert to the structured arrays used by nessai""" + # It is possible this could be made faster + return nessai.livepoint.numpy_array_to_live_points( + rfn.structured_to_unstructured(x), + self.names, + ) + + def new_point(self, N=1): + """Draw a new point""" + return self.to_live_points(self.model.prior_rvs(size=N)) + + def new_point_log_prob(self, x): + """Log-probability for the ``new_point`` method""" + return self.batch_evaluate_log_prior(x) + + def log_prior(self, x): + """Compute the log-prior""" + self.model.update(**self.to_dict(x)) + return self.model.logprior + + def log_likelihood(self, x): + """Compute the log-likelihood""" + self.model.update(**self.to_dict(x)) + return getattr(self.model, self.loglikelihood_function) + + def from_unit_hypercube(self, x): + """Map from the unit-hypercube to the prior.""" + # Needs to be implemented for importance nested sampler + # This method is already available in pycbc but the inverse is not + raise NotImplementedError + + def to_unit_hypercube(self, x): + """Map to the unit-hypercube to the prior.""" + # Needs to be implemented for importance nested sampler + raise NotImplementedError