-
Notifications
You must be signed in to change notification settings - Fork 154
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate PyMC samplers and clean up unsued MCMC sampler (#1053)
* Deleted unused MCMC class, using Pyro's leads to MP errors currently * First try on fixing the pickling error * Moved test class * Test are running * Removed the Pyro-based slice sampler * add draft pymc interface * fix dependency typo * Minor edits * Deleted empty file * improved docstrings, attempt to fix ruff and pyright * improve test * @famura improve tests * Minor test input rework * Docs * Thinning has fallen out of fashion, ask PyMC * Unified warmup_steps default to 200 * More asserts in the test * Bug fix * Improved test_api_posterior_sampler_set * Removed [:num_samples] tuncation * final test improvement and minor cleanup * ignore pyright * Process the new thinning default * Formatting * Formatting * Type fix * Formatting * attempt to fix num sample edge case * Added [:num_samples] back in * update docstring for num_chains * Fixes after merge * Doc fix * Changed tests parameters * Doc updates from review * apply suggested change to thinning docstring Co-authored-by: Jan <[email protected]> * apply suggested change to thinning docstring [2] Co-authored-by: Jan <[email protected]> * expose mc_context argument * fix assumption of default thin value --------- Co-authored-by: felixp8 <[email protected]> Co-authored-by: Felix Pei <[email protected]> Co-authored-by: Jan <[email protected]>
- Loading branch information
1 parent
7be2115
commit 6e0e98a
Showing
18 changed files
with
495 additions
and
1,555 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
from typing import Any, Callable, Optional | ||
|
||
import numpy as np | ||
import pymc | ||
import pytensor.tensor as pt | ||
import torch | ||
from arviz.data import InferenceData | ||
|
||
from sbi.utils import tensor2numpy | ||
|
||
|
||
class PyMCPotential(pt.Op): # type: ignore | ||
"""PyTensor Op wrapping a callable potential function""" | ||
|
||
itypes = [pt.dvector] # expects a vector of parameter values when called | ||
otypes = [ | ||
pt.dscalar, | ||
pt.dvector, | ||
] # outputs a single scalar value (the potential) and gradients for every input | ||
default_output = 0 # return only potential by default | ||
|
||
def __init__( | ||
self, | ||
potential_fn: Callable, | ||
device: str, | ||
track_gradients: bool = True, | ||
): | ||
"""PyTensor Op wrapping a callable potential function for use | ||
with PyMC samplers. | ||
Args: | ||
potential_fn: Potential function that returns a potential given parameters | ||
device: The device to which to move the parameters before evaluation. | ||
track_gradients: Whether to track gradients from potential function | ||
""" | ||
self.potential_fn = potential_fn | ||
self.device = device | ||
self.track_gradients = track_gradients | ||
|
||
def perform(self, node: Any, inputs: Any, outputs: Any) -> None: | ||
"""Compute potential and possibly gradients from input parameters | ||
Args: | ||
node: A "node" that represents the computation, handled internally | ||
by PyTensor. | ||
inputs: A sequence of inputs to the operation of type `itypes`. In this | ||
case, the sequence will contain one array containing the | ||
simulator parameters. | ||
outputs: A sequence allocated for storing operation outputs. In this | ||
case, the sequence will contain one scalar for the computed potential | ||
and an array containing the gradient of the potential with respect | ||
to the simulator parameters. | ||
""" | ||
# unpack and handle inputs | ||
params = inputs[0] | ||
params = ( | ||
torch.tensor(params) | ||
.to(device=self.device, dtype=torch.float32) | ||
.requires_grad_(self.track_gradients) | ||
) | ||
|
||
# call the potential function | ||
energy = self.potential_fn(params, track_gradients=self.track_gradients) | ||
|
||
# output the log-likelihood | ||
outputs[0][0] = tensor2numpy(energy).astype(np.float64) | ||
|
||
# compute and record gradients if desired | ||
if self.track_gradients: | ||
energy.backward() | ||
grads = params.grad | ||
outputs[1][0] = tensor2numpy(grads).astype(np.float64) | ||
else: | ||
outputs[1][0] = np.zeros(params.shape, dtype=np.float64) | ||
|
||
def grad(self, inputs: Any, output_grads: Any) -> list: | ||
"""Get gradients computed from `perform` and return Jacobian-Vector product | ||
Args: | ||
inputs: A sequence of inputs to the operation of type `itypes`. In this | ||
case, the sequence will contain one array containing the | ||
simulator parameters. | ||
output_grads: A sequence of the gradients of the output variables. The first | ||
element will be the gradient of the output of the whole computational | ||
graph with respect to the output of this specific operation, i.e., | ||
the potential. | ||
Returns: | ||
A list containing the gradient of the output of the whole computational | ||
graph with respect to the input of this operation, i.e., | ||
the simulator parameters. | ||
""" | ||
# get outputs from forward pass (but doesn't re-compute it, I think...) | ||
value = self(*inputs) | ||
gradients = value.owner.outputs[1:] # type: ignore | ||
# compute and return JVP | ||
return [(output_grads[0] * grad) for grad in gradients] | ||
|
||
|
||
class PyMCSampler: | ||
"""Interface for PyMC samplers""" | ||
|
||
def __init__( | ||
self, | ||
potential_fn: Callable, | ||
initvals: np.ndarray, | ||
step: str = "nuts", | ||
draws: int = 1000, | ||
tune: int = 1000, | ||
chains: Optional[int] = None, | ||
mp_ctx: str = "spawn", | ||
progressbar: bool = True, | ||
param_name: str = "theta", | ||
device: str = "cpu", | ||
): | ||
"""Interface for PyMC samplers | ||
Args: | ||
potential_fn: Potential function from density estimator. | ||
initvals: Initial parameters. | ||
step: One of `"slice"`, `"hmc"`, or `"nuts"`. | ||
draws: Number of total samples to draw. | ||
tune: Number of tuning steps to take. | ||
chains: Number of MCMC chains to run in parallel. | ||
mp_ctx: Multiprocessing context for parallel sampling. | ||
progressbar: Whether to show/hide progress bars. | ||
param_name: Name for parameter variable, for PyMC and ArviZ structures | ||
device: The device to which to move the parameters for potential_fn. | ||
""" | ||
self.param_name = param_name | ||
self._step = step | ||
self._draws = draws | ||
self._tune = tune | ||
self._initvals = [{self.param_name: iv} for iv in initvals] | ||
self._chains = chains | ||
self._mp_ctx = mp_ctx | ||
self._progressbar = progressbar | ||
self._device = device | ||
|
||
# create PyMC model object | ||
track_gradients = step in (pymc.NUTS, pymc.HamiltonianMC) | ||
self._model = pymc.Model() | ||
potential = PyMCPotential( | ||
potential_fn, track_gradients=track_gradients, device=device | ||
) | ||
with self._model: | ||
params = pymc.Normal( | ||
self.param_name, mu=initvals.mean(axis=0) | ||
) # dummy prior | ||
pymc.Potential("likelihood", potential(params)) # type: ignore | ||
|
||
def run(self) -> np.ndarray: | ||
"""Run MCMC with PyMC | ||
Returns: | ||
MCMC samples | ||
""" | ||
step_class = dict(slice=pymc.Slice, hmc=pymc.HamiltonianMC, nuts=pymc.NUTS) | ||
with self._model: | ||
inference_data = pymc.sample( | ||
step=step_class[self._step](), | ||
tune=self._tune, | ||
draws=self._draws, | ||
initvals=self._initvals, # type: ignore | ||
chains=self._chains, | ||
progressbar=self._progressbar, | ||
mp_ctx=self._mp_ctx, | ||
) | ||
self._inference_data = inference_data | ||
traces = inference_data.posterior # type: ignore | ||
samples = getattr(traces, self.param_name).data | ||
return samples | ||
|
||
def get_samples( | ||
self, num_samples: Optional[int] = None, group_by_chain: bool = True | ||
) -> np.ndarray: | ||
"""Returns samples from last call to self.run. | ||
Raises ValueError if no samples have been generated yet. | ||
Args: | ||
num_samples: Number of samples to return (for each chain if grouped by | ||
chain), if too large, all samples are returned (no error). | ||
group_by_chain: Whether to return samples grouped by chain (chain x samples | ||
x dim_params) or flattened (all_samples, dim_params). | ||
Returns: | ||
samples | ||
""" | ||
if self._inference_data is None: | ||
raise ValueError("No samples found from MCMC run.") | ||
# if not grouped by chain, flatten samples into (all_samples, dim_params) | ||
traces = self._inference_data.posterior # type: ignore | ||
samples = getattr(traces, self.param_name).data | ||
if not group_by_chain: | ||
samples = samples.reshape(-1, samples.shape[-1]) | ||
|
||
# if not specified return all samples | ||
if num_samples is None: | ||
return samples | ||
# otherwise return last num_samples (for each chain when grouped). | ||
elif group_by_chain: | ||
return samples[:, -num_samples:, :] | ||
else: | ||
return samples[-num_samples:, :] | ||
|
||
def get_inference_data(self) -> InferenceData: | ||
"""Returns InferenceData from last call to self.run, | ||
which contains diagnostic information in addition to samples | ||
Raises ValueError if no samples have been generated yet. | ||
Returns: | ||
InferenceData containing samples and sampling run information | ||
""" | ||
if self._inference_data is None: | ||
raise ValueError("No samples found from MCMC run.") | ||
return self._inference_data |
Oops, something went wrong.