Integrate PyMC samplers and clean up unused MCMC sampler (#1053)
* Deleted unused MCMC class; using Pyro's leads to MP errors currently

* First try on fixing the pickling error

* Moved test class

* Tests are running

* Removed the Pyro-based slice sampler

* add draft pymc interface

* fix dependency typo

* Minor edits

* Deleted empty file

* improved docstrings, attempt to fix ruff and pyright

* improve test

* @famura improve tests

* Minor test input rework

* Docs

* Thinning has fallen out of fashion, ask PyMC

* Unified warmup_steps default to 200

* More asserts in the test

* Bug fix

* Improved test_api_posterior_sampler_set

* Removed [:num_samples] truncation

* final test improvement and minor cleanup

* ignore pyright

* Process the new thinning default

* Formatting

* Formatting

* Type fix

* Formatting

* attempt to fix num sample edge case

* Added [:num_samples] back in

* update docstring for num_chains

* Fixes after merge

* Doc fix

* Changed tests parameters

* Doc updates from review

* apply suggested change to thinning docstring

Co-authored-by: Jan <[email protected]>

* apply suggested change to thinning docstring [2]

Co-authored-by: Jan <[email protected]>

* expose mc_context argument

* fix assumption of default thin value

---------

Co-authored-by: felixp8 <[email protected]>
Co-authored-by: Felix Pei <[email protected]>
Co-authored-by: Jan <[email protected]>
4 people authored Apr 3, 2024
1 parent 7be2115 commit 6e0e98a
Showing 18 changed files with 495 additions and 1,555 deletions.
2 changes: 1 addition & 1 deletion examples/00_HH_simulator.ipynb
@@ -256,7 +256,7 @@
"ax.set_xticks([])\n",
"ax.set_yticks([-80, -20, 40])\n",
"\n",
"# plot the injected current \n",
"# plot the injected current\n",
"ax = plt.subplot(gs[1])\n",
"plt.plot(t, I_inj * A_soma * 1e3, \"k\", lw=2)\n",
"plt.xlabel(\"time (ms)\")\n",
1 change: 1 addition & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
"torch>=1.8.0",
"tqdm",
"zuko>=1.0.0",
"pymc>=5.0.0",
]

[project.optional-dependencies]
196 changes: 158 additions & 38 deletions sbi/inference/posteriors/mcmc_posterior.py

Large diffs are not rendered by default.
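Since the `mcmc_posterior.py` diff is not rendered, the exact calling convention is not visible on this page. As a hedged sketch only: the PyMC steppers should become reachable from `MCMCPosterior` via method strings alongside the existing samplers. The method name `slice_pymc` and the toy potential below are assumptions for illustration, not confirmed by this commit.

import torch

from sbi.inference import MCMCPosterior
from sbi.utils import BoxUniform


def potential_fn(theta: torch.Tensor, track_gradients: bool = True) -> torch.Tensor:
    # Toy standard-normal log-density; stands in for a trained density estimator.
    return -0.5 * (theta**2).sum(dim=-1)


prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
posterior = MCMCPosterior(
    potential_fn=potential_fn,
    proposal=prior,
    method="slice_pymc",  # assumed method name for the PyMC slice sampler
    warmup_steps=200,  # the unified default mentioned in the commit message
)
samples = posterior.sample((100,))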

2 changes: 1 addition & 1 deletion sbi/samplers/mcmc/__init__.py
@@ -4,7 +4,7 @@
resample_given_potential_fn,
sir_init,
)
-from sbi.samplers.mcmc.slice import Slice
+from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler
from sbi.samplers.mcmc.slice_numpy import (
SliceSampler,
SliceSamplerSerial,
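With this export in place, downstream code can import the wrapper directly from the samplers namespace:

from sbi.samplers.mcmc import PyMCSampler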
Empty file removed sbi/samplers/mcmc/build_sampler.py
151 changes: 0 additions & 151 deletions sbi/samplers/mcmc/mcmc.py

This file was deleted.

218 changes: 218 additions & 0 deletions sbi/samplers/mcmc/pymc_wrapper.py
@@ -0,0 +1,218 @@
from typing import Any, Callable, Optional

import numpy as np
import pymc
import pytensor.tensor as pt
import torch
from arviz.data import InferenceData

from sbi.utils import tensor2numpy


class PyMCPotential(pt.Op):  # type: ignore
    """PyTensor Op wrapping a callable potential function."""

    itypes = [pt.dvector]  # expects a vector of parameter values when called
    otypes = [
        pt.dscalar,
        pt.dvector,
    ]  # outputs a single scalar value (the potential) and gradients for every input
    default_output = 0  # return only potential by default

    def __init__(
        self,
        potential_fn: Callable,
        device: str,
        track_gradients: bool = True,
    ):
        """PyTensor Op wrapping a callable potential function for use
        with PyMC samplers.

        Args:
            potential_fn: Potential function that returns a potential given parameters.
            device: The device to which to move the parameters before evaluation.
            track_gradients: Whether to track gradients from the potential function.
        """
        self.potential_fn = potential_fn
        self.device = device
        self.track_gradients = track_gradients

    def perform(self, node: Any, inputs: Any, outputs: Any) -> None:
        """Compute the potential and, optionally, its gradients from input parameters.

        Args:
            node: A "node" that represents the computation, handled internally
                by PyTensor.
            inputs: A sequence of inputs to the operation of type `itypes`. In this
                case, the sequence will contain one array containing the
                simulator parameters.
            outputs: A sequence allocated for storing operation outputs. In this
                case, the sequence will contain one scalar for the computed
                potential and an array containing the gradient of the potential
                with respect to the simulator parameters.
        """
        # unpack inputs and move them to the target device as a torch tensor
        params = inputs[0]
        params = (
            torch.tensor(params)
            .to(device=self.device, dtype=torch.float32)
            .requires_grad_(self.track_gradients)
        )

        # call the potential function
        energy = self.potential_fn(params, track_gradients=self.track_gradients)

        # output the potential (a log-probability)
        outputs[0][0] = tensor2numpy(energy).astype(np.float64)

        # compute and record gradients if desired
        if self.track_gradients:
            energy.backward()
            grads = params.grad
            outputs[1][0] = tensor2numpy(grads).astype(np.float64)
        else:
            outputs[1][0] = np.zeros(params.shape, dtype=np.float64)

    def grad(self, inputs: Any, output_grads: Any) -> list:
        """Get gradients computed in `perform` and return the Jacobian-vector product.

        Args:
            inputs: A sequence of inputs to the operation of type `itypes`. In this
                case, the sequence will contain one array containing the
                simulator parameters.
            output_grads: A sequence of the gradients of the output variables. The
                first element will be the gradient of the output of the whole
                computational graph with respect to the output of this specific
                operation, i.e., the potential.

        Returns:
            A list containing the gradient of the output of the whole computational
            graph with respect to the input of this operation, i.e.,
            the simulator parameters.
        """
        # symbolically re-apply the Op to access the gradient output filled in by
        # `perform`; this only builds graph structure and does not re-run the
        # forward pass
        value = self(*inputs)
        gradients = value.owner.outputs[1:]  # type: ignore
        # compute and return the JVP
        return [(output_grads[0] * grad) for grad in gradients]


class PyMCSampler:
    """Interface for PyMC samplers."""

    def __init__(
        self,
        potential_fn: Callable,
        initvals: np.ndarray,
        step: str = "nuts",
        draws: int = 1000,
        tune: int = 1000,
        chains: Optional[int] = None,
        mp_ctx: str = "spawn",
        progressbar: bool = True,
        param_name: str = "theta",
        device: str = "cpu",
    ):
        """Interface for PyMC samplers.

        Args:
            potential_fn: Potential function from the density estimator.
            initvals: Initial parameters, one row per chain.
            step: One of `"slice"`, `"hmc"`, or `"nuts"`.
            draws: Number of total samples to draw.
            tune: Number of tuning steps to take.
            chains: Number of MCMC chains to run in parallel.
            mp_ctx: Multiprocessing context for parallel sampling.
            progressbar: Whether to show/hide progress bars.
            param_name: Name for the parameter variable, for PyMC and ArviZ structures.
            device: The device to which to move the parameters for `potential_fn`.
        """
        self.param_name = param_name
        self._step = step
        self._draws = draws
        self._tune = tune
        self._initvals = [{self.param_name: iv} for iv in initvals]
        self._chains = chains
        self._mp_ctx = mp_ctx
        self._progressbar = progressbar
        self._device = device
        self._inference_data: Optional[InferenceData] = None

        # create the PyMC model object; only the gradient-based steps
        # ("nuts", "hmc") need gradients from the potential
        track_gradients = step in ("nuts", "hmc")
        self._model = pymc.Model()
        potential = PyMCPotential(
            potential_fn, track_gradients=track_gradients, device=device
        )
        with self._model:
            params = pymc.Normal(
                self.param_name, mu=initvals.mean(axis=0)
            )  # dummy prior
            pymc.Potential("likelihood", potential(params))  # type: ignore

    def run(self) -> np.ndarray:
        """Run MCMC with PyMC.

        Returns:
            MCMC samples of shape (chains, draws, dim_params).
        """
        step_class = dict(slice=pymc.Slice, hmc=pymc.HamiltonianMC, nuts=pymc.NUTS)
        with self._model:
            inference_data = pymc.sample(
                step=step_class[self._step](),
                tune=self._tune,
                draws=self._draws,
                initvals=self._initvals,  # type: ignore
                chains=self._chains,
                progressbar=self._progressbar,
                mp_ctx=self._mp_ctx,
            )
        self._inference_data = inference_data
        traces = inference_data.posterior  # type: ignore
        samples = getattr(traces, self.param_name).data
        return samples

    def get_samples(
        self, num_samples: Optional[int] = None, group_by_chain: bool = True
    ) -> np.ndarray:
        """Returns samples from the last call to `self.run`.

        Raises a `ValueError` if no samples have been generated yet.

        Args:
            num_samples: Number of samples to return (for each chain if grouped by
                chain); if too large, all samples are returned (no error).
            group_by_chain: Whether to return samples grouped by chain
                (chains x samples x dim_params) or flattened
                (all_samples, dim_params).

        Returns:
            samples
        """
        if self._inference_data is None:
            raise ValueError("No samples found from MCMC run.")
        traces = self._inference_data.posterior  # type: ignore
        samples = getattr(traces, self.param_name).data
        # if not grouped by chain, flatten samples into (all_samples, dim_params)
        if not group_by_chain:
            samples = samples.reshape(-1, samples.shape[-1])

        # if not specified, return all samples
        if num_samples is None:
            return samples
        # otherwise return the last `num_samples` (for each chain when grouped)
        elif group_by_chain:
            return samples[:, -num_samples:, :]
        else:
            return samples[-num_samples:, :]

    def get_inference_data(self) -> InferenceData:
        """Returns the `InferenceData` from the last call to `self.run`,
        which contains diagnostic information in addition to samples.

        Raises a `ValueError` if no samples have been generated yet.

        Returns:
            `InferenceData` containing samples and sampling-run information.
        """
        if self._inference_data is None:
            raise ValueError("No samples found from MCMC run.")
        return self._inference_data
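For reference, a minimal end-to-end sketch of the wrapper defined above. `PyMCSampler` and its signature come from the file itself; the potential function, dimensionality, and sampler settings are illustrative assumptions.

import numpy as np
import torch

from sbi.samplers.mcmc import PyMCSampler


def potential_fn(theta: torch.Tensor, track_gradients: bool = True) -> torch.Tensor:
    # Standard-normal log-density up to a constant; a stand-in for the
    # potential produced by sbi's density estimators.
    return -0.5 * (theta**2).sum()


# One row of initial values per chain: here, 2 chains over 3 parameters.
initvals = np.random.randn(2, 3)

sampler = PyMCSampler(
    potential_fn=potential_fn,
    initvals=initvals,
    step="nuts",
    draws=500,
    tune=200,
    chains=2,
)
samples = sampler.run()  # shape (chains, draws, dim_params)
flat = sampler.get_samples(num_samples=100, group_by_chain=False)
idata = sampler.get_inference_data()  # ArviZ InferenceData with diagnostics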