forked from gwastro/pycbc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for nessai sampler in pycbc inference (gwastro#4567)
* add basic support for nessai sampler * enable all options and resuming in nessai * fix prior bounds in nessai model * tweak resuming and samples in nessai interface * change outdir to avoid namespace conflicts * tweaks to nessai sampler class * fix nessai checkpointing and other minor tweaks * fix for reading in nessai result files * use callback for checkpointing in nessai * start addressing codeclimate issues * add nessai to auxiliary samplers * add additional comments for nessai * make simple sampler example 2d nessai does not support 1d likelihoods, so this change is neede to test nessai in the CI * fix call to rng.random * add nessai to samplers example and update plot * set minimum version for nessai * force cpu-only version of torch * add missing epsie jump proposal * add plot-marginal to samplers plot * fix whitespace * use lazy formatting in logging functions * move functions to common nested class * update for change common nested class * address more code climate issues
- Loading branch information
1 parent
52821ec
commit 47d5275
Showing
10 changed files
with
483 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,6 @@ ntemps = 4 | |
|
||
[jump_proposal-x] | ||
name = normal | ||
|
||
[jump_proposal-y] | ||
name = normal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[sampler] | ||
name = nessai | ||
nlive = 200 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Provides IO for the nessai sampler""" | ||
import numpy | ||
|
||
from .base_nested_sampler import BaseNestedSamplerFile | ||
|
||
from .posterior import read_raw_samples_from_file | ||
from .dynesty import CommonNestedMetadataIO | ||
|
||
|
||
class NessaiFile(CommonNestedMetadataIO, BaseNestedSamplerFile): | ||
"""Class to handle file IO for the ``nessai`` sampler.""" | ||
|
||
name = "nessai_file" | ||
|
||
def read_raw_samples(self, fields, raw_samples=False, seed=0): | ||
"""Reads samples from a nessai file and constructs a posterior. | ||
Using rejection sampling to resample the nested samples | ||
Parameters | ||
---------- | ||
fields : list of str | ||
The names of the parameters to load. Names must correspond to | ||
dataset names in the file's ``samples`` group. | ||
raw_samples : bool, optional | ||
Return the raw (unweighted) samples instead of the estimated | ||
posterior samples. Default is False. | ||
Returns | ||
------- | ||
dict : | ||
Dictionary of parameter fields -> samples. | ||
""" | ||
samples = read_raw_samples_from_file(self, fields) | ||
logwt = read_raw_samples_from_file(self, ['logwt'])['logwt'] | ||
loglikelihood = read_raw_samples_from_file( | ||
self, ['loglikelihood'])['loglikelihood'] | ||
if not raw_samples: | ||
n_samples = len(logwt) | ||
# Rejection sample | ||
rng = numpy.random.default_rng(seed) | ||
logwt -= logwt.max() | ||
logu = numpy.log(rng.random(n_samples)) | ||
keep = logwt > logu | ||
post = {'loglikelihood': loglikelihood[keep]} | ||
for param in fields: | ||
post[param] = samples[param][keep] | ||
return post | ||
return samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.