Skip to content

Commit

Permalink
Add support for nessai sampler in pycbc inference (gwastro#4567)
Browse files Browse the repository at this point in the history
* add basic support for nessai sampler

* enable all options and resuming in nessai

* fix prior bounds in nessai model

* tweak resuming and samples in nessai interface

* change outdir to avoid namespace conflicts

* tweaks to nessai sampler class

* fix nessai checkpointing and other minor tweaks

* fix for reading in nessai result files

* use callback for checkpointing in nessai

* start addressing codeclimate issues

* add nessai to auxiliary samplers

* add additional comments for nessai

* make simple sampler example 2d

nessai does not support 1d likelihoods, so this change is neede to test nessai in the CI

* fix call to rng.random

* add nessai to samplers example and update plot

* set minimum version for nessai

* force cpu-only version of torch

* add missing epsie jump proposal

* add plot-marginal to samplers plot

* fix whitespace

* use lazy formatting in logging functions

* move functions to common nested class

* update for change common nested class

* address more code climate issues
  • Loading branch information
mj-will authored and bhooshan-gadre committed Dec 19, 2023
1 parent 52821ec commit 47d5275
Show file tree
Hide file tree
Showing 10 changed files with 483 additions and 40 deletions.
4 changes: 4 additions & 0 deletions companion.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ cpnest
pymultinest
ultranest
https://github.com/willvousden/ptemcee/archive/master.tar.gz
# Force the cpu-only version of PyTorch
--extra-index-url https://download.pytorch.org/whl/cpu
torch
nessai>=0.11.0

# useful to look at PyCBC Live with htop
setproctitle
Expand Down
3 changes: 3 additions & 0 deletions examples/inference/samplers/epsie_stub.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ ntemps = 4

[jump_proposal-x]
name = normal

[jump_proposal-y]
name = normal
3 changes: 3 additions & 0 deletions examples/inference/samplers/nessai_stub.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[sampler]
name = nessai
nlive = 200
9 changes: 7 additions & 2 deletions examples/inference/samplers/run.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/sh
for f in cpnest_stub.ini emcee_stub.ini emcee_pt_stub.ini dynesty_stub.ini ultranest_stub.ini epsie_stub.ini; do
for f in cpnest_stub.ini emcee_stub.ini emcee_pt_stub.ini dynesty_stub.ini ultranest_stub.ini epsie_stub.ini nessai_stub.ini; do
echo $f
pycbc_inference \
--config-files `dirname $0`/simp.ini `dirname $0`/$f \
Expand All @@ -16,4 +16,9 @@ dynesty_stub.ini.hdf:dynesty \
ultranest_stub.ini.hdf:ultranest \
epsie_stub.ini.hdf:espie \
cpnest_stub.ini.hdf:cpnest \
--output-file sample.png
nessai_stub.ini.hdf:nessai \
--output-file sample.png \
--plot-contours \
--plot-marginal \
--no-contour-labels \
--no-marginal-titles
6 changes: 6 additions & 0 deletions examples/inference/samplers/simp.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ name = test_normal

[variable_params]
x =
y =

[prior-x]
name = uniform
min-x = -10
max-x = 10

[prior-y]
name = uniform
min-y = -10
max-y = 10
2 changes: 2 additions & 0 deletions pycbc/inference/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .multinest import MultinestFile
from .dynesty import DynestyFile
from .ultranest import UltranestFile
from .nessai import NessaiFile
from .posterior import PosteriorFile
from .txt import InferenceTXTFile

Expand All @@ -49,6 +50,7 @@
DynestyFile.name: DynestyFile,
PosteriorFile.name: PosteriorFile,
UltranestFile.name: UltranestFile,
NessaiFile.name: NessaiFile,
}

try:
Expand Down
76 changes: 38 additions & 38 deletions pycbc/inference/io/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,44 @@ def extra_args_parser(parser=None, skip_args=None, **kwargs):
"extracted instead.")
return parser, actions

def write_pickled_data_into_checkpoint_file(self, state):
"""Dump the sampler state into checkpoint file
"""
if 'sampler_info/saved_state' not in self:
self.create_group('sampler_info/saved_state')
dump_state(state, self, path='sampler_info/saved_state')

def read_pickled_data_from_checkpoint_file(self):
"""Load the sampler state (pickled) from checkpoint file
"""
return load_state(self, path='sampler_info/saved_state')

def write_raw_samples(self, data, parameters=None):
"""Write the nested samples to the file
"""
if 'samples' not in self:
self.create_group('samples')
write_samples_to_file(self, data, parameters=parameters,
group='samples')
def validate(self):
"""Runs a validation test.
This checks that a samples group exist, and that pickeled data can
be loaded.
Returns
-------
bool :
Whether or not the file is valid as a checkpoint file.
"""
try:
if 'sampler_info/saved_state' in self:
load_state(self, path='sampler_info/saved_state')
checkpoint_valid = True
except KeyError:
checkpoint_valid = False
return checkpoint_valid


class DynestyFile(CommonNestedMetadataIO, BaseNestedSamplerFile):
"""Class to handle file IO for the ``dynesty`` sampler."""

Expand Down Expand Up @@ -148,41 +186,3 @@ def read_raw_samples(self, fields, raw_samples=False, seed=0):
return post
else:
return samples

def write_pickled_data_into_checkpoint_file(self, state):
"""Dump the sampler state into checkpoint file
"""
if 'sampler_info/saved_state' not in self:
self.create_group('sampler_info/saved_state')
dump_state(state, self, path='sampler_info/saved_state')

def read_pickled_data_from_checkpoint_file(self):
"""Load the sampler state (pickled) from checkpoint file
"""
return load_state(self, path='sampler_info/saved_state')

def write_raw_samples(self, data, parameters=None):
"""Write the nested samples to the file
"""
if 'samples' not in self:
self.create_group('samples')
write_samples_to_file(self, data, parameters=parameters,
group='samples')

def validate(self):
"""Runs a validation test.
This checks that a samples group exist, and that pickeled data can
be loaded.
Returns
-------
bool :
Whether or not the file is valid as a checkpoint file.
"""
try:
if 'sampler_info/saved_state' in self:
load_state(self, path='sampler_info/saved_state')
checkpoint_valid = True
except KeyError:
checkpoint_valid = False
return checkpoint_valid
49 changes: 49 additions & 0 deletions pycbc/inference/io/nessai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Provides IO for the nessai sampler"""
import numpy

from .base_nested_sampler import BaseNestedSamplerFile

from .posterior import read_raw_samples_from_file
from .dynesty import CommonNestedMetadataIO


class NessaiFile(CommonNestedMetadataIO, BaseNestedSamplerFile):
"""Class to handle file IO for the ``nessai`` sampler."""

name = "nessai_file"

def read_raw_samples(self, fields, raw_samples=False, seed=0):
"""Reads samples from a nessai file and constructs a posterior.
Using rejection sampling to resample the nested samples
Parameters
----------
fields : list of str
The names of the parameters to load. Names must correspond to
dataset names in the file's ``samples`` group.
raw_samples : bool, optional
Return the raw (unweighted) samples instead of the estimated
posterior samples. Default is False.
Returns
-------
dict :
Dictionary of parameter fields -> samples.
"""
samples = read_raw_samples_from_file(self, fields)
logwt = read_raw_samples_from_file(self, ['logwt'])['logwt']
loglikelihood = read_raw_samples_from_file(
self, ['loglikelihood'])['loglikelihood']
if not raw_samples:
n_samples = len(logwt)
# Rejection sample
rng = numpy.random.default_rng(seed)
logwt -= logwt.max()
logu = numpy.log(rng.random(n_samples))
keep = logwt > logu
post = {'loglikelihood': loglikelihood[keep]}
for param in fields:
post[param] = samples[param][keep]
return post
return samples
7 changes: 7 additions & 0 deletions pycbc/inference/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@
except ImportError:
pass

try:
from .nessai import NessaiSampler
samplers[NessaiSampler.name] = NessaiSampler
except ImportError:
pass


def load_from_config(cp, model, **kwargs):
"""Loads a sampler from the given config file.
Expand Down
Loading

0 comments on commit 47d5275

Please sign in to comment.