diff --git a/examples/00_HH_simulator.ipynb b/examples/00_HH_simulator.ipynb index 269cd3663..3b88cce0b 100644 --- a/examples/00_HH_simulator.ipynb +++ b/examples/00_HH_simulator.ipynb @@ -256,7 +256,7 @@ "ax.set_xticks([])\n", "ax.set_yticks([-80, -20, 40])\n", "\n", - "# plot the injected current \n", + "# plot the injected current\n", "ax = plt.subplot(gs[1])\n", "plt.plot(t, I_inj * A_soma * 1e3, \"k\", lw=2)\n", "plt.xlabel(\"time (ms)\")\n", diff --git a/pyproject.toml b/pyproject.toml index 4b579a0d3..fd0f9c49f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "torch>=1.8.0", "tqdm", "zuko>=1.0.0", + "pymc>=5.0.0", ] [project.optional-dependencies] diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 1079f4602..75d8b1cdb 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -21,7 +21,7 @@ from sbi.inference.potentials.base_potential import BasePotential from sbi.samplers.mcmc import ( IterateParameters, - Slice, + PyMCSampler, SliceSamplerSerial, SliceSamplerVectorized, proposal_init, @@ -46,13 +46,14 @@ def __init__( proposal: Any, theta_transform: Optional[TorchTransform] = None, method: str = "slice_np", - thin: int = 10, - warmup_steps: int = 10, + thin: int = -1, + warmup_steps: int = 200, num_chains: int = 1, init_strategy: str = "resample", init_strategy_parameters: Optional[Dict[str, Any]] = None, init_strategy_num_candidates: Optional[int] = None, num_workers: int = 1, + mp_context: str = "spawn", device: Optional[str] = None, x_shape: Optional[torch.Size] = None, ): @@ -64,14 +65,17 @@ def __init__( theta_transform: Transformation that will be applied during sampling. Allows to perform MCMC in unconstrained space. method: Method used for MCMC sampling, one of `slice_np`, - `slice_np_vectorized`, `slice`, `hmc`, `nuts`. `slice_np` is a custom + `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`, + `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom numpy implementation of slice sampling. `slice_np_vectorized` is identical to `slice_np`, but if `num_chains>1`, the chains are vectorized for `slice_np_vectorized` whereas they are run sequentially - for `slice_np`. The samplers `hmc`, `nuts` or `slice` sample with Pyro. - thin: The thinning factor for the chain. + for `slice_np`. The samplers ending on `_pyro` are using Pyro, and + likewise the samplers ending on `_pymc` are using PyMC. + thin: The thinning factor for the chain, default 1 (no thinning). warmup_steps: The initial number of samples to discard. - num_chains: The number of chains. + num_chains: The number of chains. Should generally be at most + `num_workers - 1`. init_strategy: The initialisation strategy for chains; `proposal` will draw init locations from `proposal`, whereas `sir` will use Sequential- Importance-Resampling (SIR). SIR initially samples @@ -82,17 +86,32 @@ def __init__( uses `exp(potential_fn)` as weights. init_strategy_parameters: Dictionary of keyword arguments passed to the init strategy, e.g., for `init_strategy=sir` this could be - `num_candidate_samples`, i.e., the number of candidates to to find init + `num_candidate_samples`, i.e., the number of candidates to find init locations (internal default is `1000`), or `device`. - init_strategy_num_candidates: Number of candidates to to find init + init_strategy_num_candidates: Number of candidates to find init locations in `init_strategy=sir` (deprecated, use init_strategy_parameters instead). 
num_workers: number of cpu cores used to parallelize mcmc + mp_context: Multiprocessing start method, either `"fork"` or `"spawn"` + (default), used by Pyro and PyMC samplers. `"fork"` can be significantly + faster than `"spawn"` but is only supported on POSIX-based systems + (e.g. Linux and macOS, not Windows). device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None, `potential_fn.device` is used. x_shape: Shape of a single simulator output. If passed, it is used to check the shape of the observed data and give a descriptive error. """ + if method == "slice": + warn( + "The Pyro-based slice sampler is deprecated, and the method `slice` " + "has been changed to `slice_np`, i.e., the custom " + "numpy-based slice sampler.", + DeprecationWarning, + stacklevel=2, + ) + method = "slice_np" + + thin = _process_thin_default(thin) super().__init__( potential_fn, @@ -109,6 +128,7 @@ def __init__( self.init_strategy = init_strategy self.init_strategy_parameters = init_strategy_parameters or {} self.num_workers = num_workers + self.mp_context = mp_context self._posterior_sampler = None # Hardcode parameter name to reduce clutter kwargs. self.param_name = "theta" @@ -202,6 +222,7 @@ def sample( mcmc_method: Optional[str] = None, sample_with: Optional[str] = None, num_workers: Optional[int] = None, + mp_context: Optional[str] = None, show_progress_bars: bool = True, ) -> Tensor: r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC. @@ -233,6 +254,7 @@ def sample( num_chains = self.num_chains if num_chains is None else num_chains init_strategy = self.init_strategy if init_strategy is None else init_strategy num_workers = self.num_workers if num_workers is None else num_workers + mp_context = self.mp_context if mp_context is None else mp_context init_strategy_parameters = ( self.init_strategy_parameters if init_strategy_parameters is None @@ -289,7 +311,7 @@ def sample( ) num_samples = torch.Size(sample_shape).numel() - track_gradients = method in ("hmc", "nuts") + track_gradients = method in ("hmc_pyro", "nuts_pyro", "hmc_pymc", "nuts_pymc") with torch.set_grad_enabled(track_gradients): if method in ("slice_np", "slice_np_vectorized"): transformed_samples = self._slice_np_mcmc( @@ -302,7 +324,7 @@ def sample( num_workers=num_workers, show_progress_bars=show_progress_bars, ) - elif method in ("hmc", "nuts", "slice"): + elif method in ("hmc_pyro", "nuts_pyro"): transformed_samples = self._pyro_mcmc( num_samples=num_samples, potential_function=self.potential_, @@ -312,9 +334,22 @@ def sample( warmup_steps=warmup_steps, # type: ignore num_chains=num_chains, show_progress_bars=show_progress_bars, + mp_context=mp_context, + ) + elif method in ("hmc_pymc", "nuts_pymc", "slice_pymc"): + transformed_samples = self._pymc_mcmc( + num_samples=num_samples, + potential_function=self.potential_, + initial_params=initial_params, + mcmc_method=method, # type: ignore + thin=thin, # type: ignore + warmup_steps=warmup_steps, # type: ignore + num_chains=num_chains, + show_progress_bars=show_progress_bars, + mp_context=mp_context, ) else: - raise NameError + raise NameError(f"The sampling method {method} is not implemented!") samples = self.theta_transform.inv(transformed_samples) @@ -452,9 +487,10 @@ def _slice_np_mcmc( num_samples: Desired number of samples. potential_function: A callable **class**. initial_params: Initial parameters for MCMC chain. - thin: Thinning (subsampling) factor. + thin: Thinning (subsampling) factor, default 1 (no thinning). 
warmup_steps: Initial number of samples to discard. - vectorized: Whether to use a vectorized implementation of the Slice sampler. + vectorized: Whether to use a vectorized implementation of the + `SliceSampler`. num_workers: Number of CPU cores to use. init_width: Inital width of brackets. show_progress_bars: Whether to show a progressbar during sampling; @@ -494,8 +530,7 @@ def _slice_np_mcmc( self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) # Collect samples from all chains. - samples = samples.reshape(-1, dim_samples)[:num_samples, :] - assert samples.shape[0] == num_samples + samples = samples.reshape(-1, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) @@ -504,21 +539,24 @@ def _pyro_mcmc( num_samples: int, potential_function: Callable, initial_params: Tensor, - mcmc_method: str = "slice", - thin: int = 10, + mcmc_method: str = "nuts_pyro", + thin: int = -1, warmup_steps: int = 200, num_chains: Optional[int] = 1, show_progress_bars: bool = True, + mp_context: str = "spawn", ) -> Tensor: - r"""Return samples obtained using Pyro HMC, NUTS for slice kernels. + r"""Return samples obtained using Pyro's HMC or NUTS sampler. Args: num_samples: Desired number of samples. potential_function: A callable **class**. A class, but not a function, is picklable for Pyro MCMC to use it across chains in parallel, even when the potential function requires evaluating a neural network. - mcmc_method: One of `hmc`, `nuts` or `slice`. - thin: Thinning (subsampling) factor. + initial_params: Initial parameters for MCMC chain. + mcmc_method: Pyro MCMC method to use, either `"hmc_pyro"` or + `"nuts_pyro"` (default). + thin: Thinning (subsampling) factor, default 1 (no thinning). warmup_steps: Initial number of samples to discard. num_chains: Whether to sample in parallel. If None, use all but one CPU. show_progress_bars: Whether to show a progressbar during sampling. @@ -526,17 +564,17 @@ def _pyro_mcmc( Returns: Tensor of shape (num_samples, shape_of_single_theta). """ + thin = _process_thin_default(thin) num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains - - kernels = dict(slice=Slice, hmc=HMC, nuts=NUTS) + kernels = dict(hmc_pyro=HMC, nuts_pyro=NUTS) sampler = MCMC( kernel=kernels[mcmc_method](potential_fn=potential_function), - num_samples=(thin * num_samples) // num_chains + num_chains, + num_samples=ceil((thin * num_samples) / num_chains), warmup_steps=warmup_steps, initial_params={self.param_name: initial_params}, num_chains=num_chains, - mp_context="spawn", + mp_context=mp_context, disable_progbar=not show_progress_bars, transforms={}, ) @@ -550,10 +588,66 @@ def _pyro_mcmc( self._posterior_sampler = sampler samples = samples[::thin][:num_samples] - assert samples.shape[0] == num_samples return samples.detach() + def _pymc_mcmc( + self, + num_samples: int, + potential_function: Callable, + initial_params: Tensor, + mcmc_method: str = "nuts_pymc", + thin: int = -1, + warmup_steps: int = 200, + num_chains: Optional[int] = 1, + show_progress_bars: bool = True, + mp_context: str = "spawn", + ) -> Tensor: + r"""Return samples obtained using PyMC's HMC, NUTS or slice samplers. + + Args: + num_samples: Desired number of samples. + potential_function: A callable **class**. A class, but not a function, + is picklable for PyMC MCMC to use it across chains in parallel, + even when the potential function requires evaluating a neural network. + initial_params: Initial parameters for MCMC chain. 
+            mcmc_method: PyMC sampling method to use, one of `"hmc_pymc"`,
+                `"slice_pymc"`, or `"nuts_pymc"` (default).
+            thin: Thinning (subsampling) factor, default 1 (no thinning).
+            warmup_steps: Initial number of samples to discard.
+            num_chains: Number of chains to sample in parallel. If None, use all
+                but one CPU.
+            show_progress_bars: Whether to show a progressbar during sampling.
+
+        Returns:
+            Tensor of shape (num_samples, shape_of_single_theta).
+        """
+        thin = _process_thin_default(thin)
+        num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains
+        steps = dict(slice_pymc="slice", hmc_pymc="hmc", nuts_pymc="nuts")
+
+        sampler = PyMCSampler(
+            potential_fn=potential_function,
+            step=steps[mcmc_method],
+            initvals=tensor2numpy(initial_params),
+            draws=ceil((thin * num_samples) / num_chains),
+            tune=warmup_steps,
+            chains=num_chains,
+            mp_ctx=mp_context,
+            progressbar=show_progress_bars,
+            param_name=self.param_name,
+            device=self._device,
+        )
+        samples = sampler.run()
+        samples = torch.from_numpy(samples).to(dtype=torch.float32, device=self._device)
+        samples = samples.reshape(-1, initial_params.shape[1])
+
+        # Save posterior sampler.
+        self._posterior_sampler = sampler
+
+        samples = samples[::thin][:num_samples]
+
+        return samples
+
     def _prepare_potential(self, method: str) -> Callable:
         """Combines potential and transform and takes care of gradients and pyro.
 
@@ -563,13 +657,13 @@ def _prepare_potential(self, method: str) -> Callable:
         Returns:
             A potential function that is ready to be used in MCMC.
         """
-        if method == "slice":
-            track_gradients = False
-            pyro = True
-        elif method in ("hmc", "nuts"):
+        if method in ("hmc_pyro", "nuts_pyro"):
            track_gradients = True
            pyro = True
-        elif "slice_np" in method:
+        elif method in ("hmc_pymc", "nuts_pymc"):
+            track_gradients = True
+            pyro = False
+        elif method in ("slice_np", "slice_np_vectorized", "slice_pymc"):
            track_gradients = False
            pyro = False
         else:
@@ -662,8 +756,8 @@ def get_arviz_inference_data(self) -> InferenceData:
         Note: the InferenceData is constructed using the posterior samples generated in
         most recent call to `.sample(...)`.
 
-        For Pyro HMC and NUTS kernels InferenceData will contain diagnostics, for Pyro
-        Slice or sbi slice sampling samples, only the samples are added.
+        For Pyro and PyMC samplers, InferenceData will contain diagnostics, but for
+        sbi slice samplers, only the samples are added.
 
         Returns:
             inference_data: Arviz InferenceData object.
@@ -672,16 +766,20 @@
             self._posterior_sampler is not None
         ), """No samples have been generated, call .sample() first."""
 
-        sampler: Union[MCMC, SliceSamplerSerial, SliceSamplerVectorized] = (
-            self._posterior_sampler
-        )
+        sampler: Union[
+            MCMC, SliceSamplerSerial, SliceSamplerVectorized, PyMCSampler
+        ] = self._posterior_sampler
 
         # If Pyro sampler and samples not transformed, use arviz' from_pyro.
-        # Exclude 'slice' kernel as it lacks the 'divergence' diagnostics key.
-        if isinstance(self._posterior_sampler, (HMC, NUTS)) and isinstance(
+        if isinstance(sampler, (HMC, NUTS)) and isinstance(
             self.theta_transform, torch_tf.IndependentTransform
         ):
             inference_data = az.from_pyro(sampler)
+        # If PyMC sampler and samples not transformed, get cached InferenceData.
+        elif isinstance(sampler, PyMCSampler) and isinstance(
+            self.theta_transform, torch_tf.IndependentTransform
+        ):
+            inference_data = sampler.get_inference_data()
         # otherwise get samples from sampler and transform to original space.
        else:
@@ -711,6 +809,28 @@
     return inference_data
+
+
+def _process_thin_default(thin: int) -> int:
+    """
+    Check whether the user kept the default thinning value and warn if so.
+
+    Args:
+        thin: Thinning (subsampling) factor; a value of 1 disables thinning.
+
+    Returns:
+        The corrected thinning factor.
+    """
+    if thin == -1:
+        thin = 1
+        warn(
+            "The default value for thinning in MCMC sampling has been changed "
+            "from 10 to 1. This may cause results to differ from earlier "
+            "benchmarks.",
+            UserWarning,
+            stacklevel=2,
+        )
+
+    return thin
+
+
 def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
     """Returns `default` if `key` is not in the dict and otherwise the dict entry.
 
diff --git a/sbi/samplers/mcmc/__init__.py b/sbi/samplers/mcmc/__init__.py
index 7d3abe146..bd62f4436 100644
--- a/sbi/samplers/mcmc/__init__.py
+++ b/sbi/samplers/mcmc/__init__.py
@@ -4,7 +4,7 @@
     resample_given_potential_fn,
     sir_init,
 )
-from sbi.samplers.mcmc.slice import Slice
+from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler
 from sbi.samplers.mcmc.slice_numpy import (
     SliceSampler,
     SliceSamplerSerial,
diff --git a/sbi/samplers/mcmc/build_sampler.py b/sbi/samplers/mcmc/build_sampler.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/sbi/samplers/mcmc/mcmc.py b/sbi/samplers/mcmc/mcmc.py
deleted file mode 100644
index f6e4ba7fb..000000000
--- a/sbi/samplers/mcmc/mcmc.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Copyright (c) 2017-2019 Uber Technologies, Inc.
-# SPDX-License-Identifier: Apache-2.0
-
-import multiprocessing as mp
-import warnings
-
-from pyro.infer.mcmc import MCMC as BaseMCMC
-from pyro.infer.mcmc.api import _MultiSampler, _UnarySampler
-from pyro.infer.mcmc.hmc import HMC
-from pyro.infer.mcmc.nuts import NUTS
-
-
-class MCMC(BaseMCMC):
-    """
-    Identical to Pyro's MCMC class except for `available_cpu` parameter.
-
-    Wrapper class for Markov Chain Monte Carlo algorithms. Specific MCMC algorithms
-    are TraceKernel instances and need to be supplied as a ``kernel`` argument
-    to the constructor.
-
-    .. note:: The case of `num_chains > 1` uses python multiprocessing to
-        run parallel chains in multiple processes. This goes with the usual
-        caveats around multiprocessing in python, e.g. the model used to
-        initialize the ``kernel`` must be serializable via `pickle`, and the
-        performance / constraints will be platform dependent (e.g. only
-        the "spawn" context is available in Windows). This has also not
-        been extensively tested on the Windows platform.
-
-    :param kernel: An instance of the ``TraceKernel`` class, which when
-        given an execution trace returns another sample trace from the target
-        (posterior) distribution.
-    :param int num_samples: The number of samples that need to be generated,
-        excluding the samples discarded during the warmup phase.
-    :param int warmup_steps: Number of warmup iterations. The samples generated
-        during the warmup phase are discarded. If not provided, default is
-        half of `num_samples`.
-    :param int num_chains: Number of MCMC chains to run in parallel. Depending on
-        whether `num_chains` is 1 or more than 1, this class internally dispatches
-        to either `_UnarySampler` or `_MultiSampler`.
-    :param dict initial_params: dict containing initial tensors in unconstrained
-        space to initiate the markov chain. The leading dimension's size must match
-        that of `num_chains`. If not specified, parameter values will be sampled from
-        the prior.
- :param hook_fn: Python callable that takes in `(kernel, samples, stage, i)` - as arguments. stage is either `sample` or `warmup` and i refers to the - i'th sample for the given stage. This can be used to implement additional - logging, or more generally, run arbitrary code per generated sample. - :param str mp_context: Multiprocessing context to use when `num_chains > 1`. - Only applicable for Python 3.5 and above. Use `mp_context="spawn"` for - CUDA. - :param bool disable_progbar: Disable progress bar and diagnostics update. - :param bool disable_validation: Disables distribution validation check. This is - disabled by default, since divergent transitions will lead to exceptions. - Switch to `True` for debugging purposes. - :param dict transforms: dictionary that specifies a transform for a sample site - with constrained support to unconstrained space. - :param int available_cpu: Number of available CPUs, defaults to `mp.cpu_count()-1`. - Setting it to 1 disables multiprocessing. - """ - - def __init__( - self, - kernel, - num_samples, - warmup_steps=None, - initial_params=None, - num_chains=1, - hook_fn=None, - mp_context=None, - disable_progbar=False, - disable_validation=True, - transforms=None, - available_cpu=mp.cpu_count() - 1, - ): - self.warmup_steps = ( - num_samples if warmup_steps is None else warmup_steps - ) # Stan - self.num_samples = num_samples - self.kernel = kernel - self.transforms = transforms - self.disable_validation = disable_validation - self._samples = None - self._args = None - self._kwargs = None - if ( - isinstance(self.kernel, (HMC, NUTS)) - and self.kernel.potential_fn is not None - and initial_params is None - ): - raise ValueError( - "Must provide valid initial parameters to begin sampling" - " when using `potential_fn` in HMC/NUTS kernel." - ) - parallel = False - if num_chains > 1: - # check that initial_params is different for each chain - if initial_params: - for v in initial_params.values(): - if v.shape[0] != num_chains: - raise ValueError( - "The leading dimension of tensors in `initial_params` " - "must match the number of chains." - ) - # FIXME: probably we want to use "spawn" method by default to avoid the - # error CUDA initialization error - # https://github.com/pytorch/pytorch/issues/2517 even that we run MCMC - # in CPU. - # change multiprocessing context to 'spawn' for CUDA tensors. - if mp_context is None and list(initial_params.values())[0].is_cuda: - mp_context = "spawn" - - # verify num_chains is compatible with available CPU. - available_cpu = max(available_cpu, 1) - if num_chains <= available_cpu: - parallel = True - else: - warnings.warn( - "num_chains={} is more than available_cpu={}. 
" - "Chains will be drawn sequentially.".format( - num_chains, available_cpu - ), - stacklevel=2, - ) - else: - if initial_params: - initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()} - - self.num_chains = num_chains - self._diagnostics = [None] * num_chains - - if parallel: - self.sampler = _MultiSampler( - kernel, - num_samples, - self.warmup_steps, - num_chains, - mp_context, - disable_progbar, - initial_params=initial_params, - hook=hook_fn, - ) - else: - self.sampler = _UnarySampler( - kernel, - num_samples, - self.warmup_steps, - num_chains, - disable_progbar, - initial_params=initial_params, - hook=hook_fn, - ) diff --git a/sbi/samplers/mcmc/pymc_wrapper.py b/sbi/samplers/mcmc/pymc_wrapper.py new file mode 100644 index 000000000..8d9bcfd4a --- /dev/null +++ b/sbi/samplers/mcmc/pymc_wrapper.py @@ -0,0 +1,218 @@ +from typing import Any, Callable, Optional + +import numpy as np +import pymc +import pytensor.tensor as pt +import torch +from arviz.data import InferenceData + +from sbi.utils import tensor2numpy + + +class PyMCPotential(pt.Op): # type: ignore + """PyTensor Op wrapping a callable potential function""" + + itypes = [pt.dvector] # expects a vector of parameter values when called + otypes = [ + pt.dscalar, + pt.dvector, + ] # outputs a single scalar value (the potential) and gradients for every input + default_output = 0 # return only potential by default + + def __init__( + self, + potential_fn: Callable, + device: str, + track_gradients: bool = True, + ): + """PyTensor Op wrapping a callable potential function for use + with PyMC samplers. + + Args: + potential_fn: Potential function that returns a potential given parameters + device: The device to which to move the parameters before evaluation. + track_gradients: Whether to track gradients from potential function + """ + self.potential_fn = potential_fn + self.device = device + self.track_gradients = track_gradients + + def perform(self, node: Any, inputs: Any, outputs: Any) -> None: + """Compute potential and possibly gradients from input parameters + + Args: + node: A "node" that represents the computation, handled internally + by PyTensor. + inputs: A sequence of inputs to the operation of type `itypes`. In this + case, the sequence will contain one array containing the + simulator parameters. + outputs: A sequence allocated for storing operation outputs. In this + case, the sequence will contain one scalar for the computed potential + and an array containing the gradient of the potential with respect + to the simulator parameters. + """ + # unpack and handle inputs + params = inputs[0] + params = ( + torch.tensor(params) + .to(device=self.device, dtype=torch.float32) + .requires_grad_(self.track_gradients) + ) + + # call the potential function + energy = self.potential_fn(params, track_gradients=self.track_gradients) + + # output the log-likelihood + outputs[0][0] = tensor2numpy(energy).astype(np.float64) + + # compute and record gradients if desired + if self.track_gradients: + energy.backward() + grads = params.grad + outputs[1][0] = tensor2numpy(grads).astype(np.float64) + else: + outputs[1][0] = np.zeros(params.shape, dtype=np.float64) + + def grad(self, inputs: Any, output_grads: Any) -> list: + """Get gradients computed from `perform` and return Jacobian-Vector product + + Args: + inputs: A sequence of inputs to the operation of type `itypes`. In this + case, the sequence will contain one array containing the + simulator parameters. 
+ output_grads: A sequence of the gradients of the output variables. The first + element will be the gradient of the output of the whole computational + graph with respect to the output of this specific operation, i.e., + the potential. + + Returns: + A list containing the gradient of the output of the whole computational + graph with respect to the input of this operation, i.e., + the simulator parameters. + """ + # get outputs from forward pass (but doesn't re-compute it, I think...) + value = self(*inputs) + gradients = value.owner.outputs[1:] # type: ignore + # compute and return JVP + return [(output_grads[0] * grad) for grad in gradients] + + +class PyMCSampler: + """Interface for PyMC samplers""" + + def __init__( + self, + potential_fn: Callable, + initvals: np.ndarray, + step: str = "nuts", + draws: int = 1000, + tune: int = 1000, + chains: Optional[int] = None, + mp_ctx: str = "spawn", + progressbar: bool = True, + param_name: str = "theta", + device: str = "cpu", + ): + """Interface for PyMC samplers + + Args: + potential_fn: Potential function from density estimator. + initvals: Initial parameters. + step: One of `"slice"`, `"hmc"`, or `"nuts"`. + draws: Number of total samples to draw. + tune: Number of tuning steps to take. + chains: Number of MCMC chains to run in parallel. + mp_ctx: Multiprocessing context for parallel sampling. + progressbar: Whether to show/hide progress bars. + param_name: Name for parameter variable, for PyMC and ArviZ structures + device: The device to which to move the parameters for potential_fn. + """ + self.param_name = param_name + self._step = step + self._draws = draws + self._tune = tune + self._initvals = [{self.param_name: iv} for iv in initvals] + self._chains = chains + self._mp_ctx = mp_ctx + self._progressbar = progressbar + self._device = device + + # create PyMC model object + track_gradients = step in (pymc.NUTS, pymc.HamiltonianMC) + self._model = pymc.Model() + potential = PyMCPotential( + potential_fn, track_gradients=track_gradients, device=device + ) + with self._model: + params = pymc.Normal( + self.param_name, mu=initvals.mean(axis=0) + ) # dummy prior + pymc.Potential("likelihood", potential(params)) # type: ignore + + def run(self) -> np.ndarray: + """Run MCMC with PyMC + + Returns: + MCMC samples + """ + step_class = dict(slice=pymc.Slice, hmc=pymc.HamiltonianMC, nuts=pymc.NUTS) + with self._model: + inference_data = pymc.sample( + step=step_class[self._step](), + tune=self._tune, + draws=self._draws, + initvals=self._initvals, # type: ignore + chains=self._chains, + progressbar=self._progressbar, + mp_ctx=self._mp_ctx, + ) + self._inference_data = inference_data + traces = inference_data.posterior # type: ignore + samples = getattr(traces, self.param_name).data + return samples + + def get_samples( + self, num_samples: Optional[int] = None, group_by_chain: bool = True + ) -> np.ndarray: + """Returns samples from last call to self.run. + + Raises ValueError if no samples have been generated yet. + + Args: + num_samples: Number of samples to return (for each chain if grouped by + chain), if too large, all samples are returned (no error). + group_by_chain: Whether to return samples grouped by chain (chain x samples + x dim_params) or flattened (all_samples, dim_params). 
+ + Returns: + samples + """ + if self._inference_data is None: + raise ValueError("No samples found from MCMC run.") + # if not grouped by chain, flatten samples into (all_samples, dim_params) + traces = self._inference_data.posterior # type: ignore + samples = getattr(traces, self.param_name).data + if not group_by_chain: + samples = samples.reshape(-1, samples.shape[-1]) + + # if not specified return all samples + if num_samples is None: + return samples + # otherwise return last num_samples (for each chain when grouped). + elif group_by_chain: + return samples[:, -num_samples:, :] + else: + return samples[-num_samples:, :] + + def get_inference_data(self) -> InferenceData: + """Returns InferenceData from last call to self.run, + which contains diagnostic information in addition to samples + + Raises ValueError if no samples have been generated yet. + + Returns: + InferenceData containing samples and sampling run information + """ + if self._inference_data is None: + raise ValueError("No samples found from MCMC run.") + return self._inference_data diff --git a/sbi/samplers/mcmc/slice.py b/sbi/samplers/mcmc/slice.py deleted file mode 100644 index 3359d3fe0..000000000 --- a/sbi/samplers/mcmc/slice.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) 2017-2019 Uber Technologies, Inc. -# SPDX-License-Identifier: Apache-2.0 - -from copy import deepcopy -from typing import Callable, Dict, Optional - -import torch -from pyro.infer.mcmc.mcmc_kernel import MCMCKernel -from pyro.infer.mcmc.util import initialize_model -from torch import Tensor - - -class Slice(MCMCKernel): - def __init__( - self, - model: Optional[Callable] = None, - potential_fn: Optional[Callable] = None, - initial_width: float = 0.01, - max_width=float("inf"), - transforms: Optional[Dict] = None, - max_plate_nesting: Optional[int] = None, - jit_compile: bool = False, - jit_options: Optional[Dict] = None, - ignore_jit_warnings: bool = False, - ) -> None: - """ - Slice sampling kernel [1]. - - During the warmup phase, the width of the bracket is adapted, starting from - the provided initial width. - - **References** - - [1] `Slice Sampling `_, - Radford M. Neal - - Args: - model: Python callable containing Pyro primitives. - potential_fn: Python callable calculating potential energy with input - is a dict of real support parameters. - initial_width: Initial bracket width - max_width: Maximum bracket width - transforms: Optional dictionary that specifies a transform - for a sample site with constrained support to unconstrained space. The - transform should be invertible, and implement `log_abs_det_jacobian`. - If not specified and the model has sites with constrained support, - automatic transformations will be applied, as specified in - :mod:`torch.distributions.constraint_registry`. - max_plate_nesting: Optional bound on max number of nested - :func:`pyro.plate` contexts. This is required if model contains - discrete sample sites that can be enumerated over in parallel. - jit_compile: Optional parameter denoting whether to use - the PyTorch JIT to trace the log density computation, and use this - optimized executable trace in the integrator. - jit_options: A dictionary contains optional arguments for - :func:`torch.jit.trace` function. - ignore_jit_warnings: Flag to ignore warnings from the JIT - tracer when ``jit_compile=True``. Default is False. 
- """ - if not ((model is None) ^ (potential_fn is None)): - raise ValueError("Only one of `model` or `potential_fn` must be specified.") - # NB: deprecating args - model, transforms - self.model = model - self.transforms = transforms - self._max_plate_nesting = max_plate_nesting - self._jit_compile = jit_compile - self._jit_options = jit_options - self._ignore_jit_warnings = ignore_jit_warnings - - self.potential_fn = potential_fn - - self._initial_width = initial_width - self._max_width = max_width - - self._reset() - - super(Slice, self).__init__() - - def _reset(self): - self._t = 0 - self._width: Optional[Tensor] = None - self._num_dimensions: Optional[int] = None - self._initial_params: Optional[Dict] = None - self._site_name = None - - def setup(self, warmup_steps, *args, **kwargs): - self._warmup_steps = warmup_steps - if self.model is not None: - self._initialize_model_properties(args, kwargs) - - # TODO: Clean up required for multiple sites - assert self.initial_params is not None - self._site_name = next(iter(self.initial_params.keys())) - self._num_dimensions = next(iter(self.initial_params.values())).shape[-1] - - self._width = torch.full((self._num_dimensions,), self._initial_width) - - @property - def initial_params(self): - return deepcopy(self._initial_params) - - @initial_params.setter - def initial_params(self, params): - assert ( - isinstance(params, dict) and len(params) == 1 - ), "Slice sampling only implemented for a single site." # TODO: Implement - self._initial_params = params - - def _initialize_model_properties(self, model_args, model_kwargs): - init_params, potential_fn, transforms, trace = initialize_model( - self.model, - model_args, - model_kwargs, - transforms=self.transforms, - max_plate_nesting=self._max_plate_nesting, - jit_compile=self._jit_compile, - jit_options=self._jit_options, - skip_jit_warnings=self._ignore_jit_warnings, - ) - self.potential_fn = potential_fn - self.transforms = transforms - if self._initial_params is None: - self.initial_params = init_params - self._prototype_trace = trace - - def cleanup(self): - self._reset() - - def sample(self, params): - assert ( - self._num_dimensions is not None and self._width is not None - ), "Chain not initialized." - - for dim in torch.randperm(self._num_dimensions): - # cast for pyright. - idx = int(dim.item()) - ( - params[self._site_name].view(-1)[idx], - width_d, - ) = self._sample_from_conditional(params, idx) - if self._t < self._warmup_steps: - # TODO: Other schemes for tuning bracket width? - self._width[idx] += (width_d.item() - self._width[idx]) / (self._t + 1) - - self._t += 1 - - return params.copy() - - def _sample_from_conditional(self, params, dim): - # TODO: Flag for doubling and stepping out procedures, see Neal paper, and also: - # https://pints.readthedocs.io/en/latest/mcmc_samplers/slice_doubling_mcmc.html - # https://pints.readthedocs.io/en/latest/mcmc_samplers/slice_stepout_mcmc.html - - def _log_prob_d(x): - assert self.potential_fn is not None, "Chain not initialized." - - return -self.potential_fn({ - self._site_name: torch.cat(( - params[self._site_name].view(-1)[:dim], - x.reshape(1), - params[self._site_name].view(-1)[dim + 1 :], - )).unsqueeze( - 0 - ) # TODO: The unsqueeze seems to give a speed up, figure out when - # this is the case exactly - }) - - assert ( - self._site_name is not None and self._width is not None - ), "Chain not initialized." 
- - # Sample uniformly from slice - log_height = _log_prob_d(params[self._site_name].view(-1)[dim]) + torch.log( - torch.rand(1, device=params[self._site_name].device) - ) - - # Position the bracket randomly around the current sample - lower = params[self._site_name].view(-1)[dim] - self._width[dim] * torch.rand( - 1, device=params[self._site_name].device - ) - upper = lower + self._width[dim] - - # Find lower bracket end - while ( - _log_prob_d(lower) >= log_height - and params[self._site_name].view(-1)[dim] - lower < self._max_width - ): - lower -= self._width[dim] - - # Find upper bracket end - while ( - _log_prob_d(upper) >= log_height - and upper - params[self._site_name].view(-1)[dim] < self._max_width - ): - upper += self._width[dim] - - # Sample uniformly from bracket - new_parameter = (upper - lower) * torch.rand( - 1, device=params[self._site_name].device - ) + lower - - # If outside slice, reject sample and shrink bracket - while _log_prob_d(new_parameter) < log_height: - if new_parameter < params[self._site_name].view(-1)[dim]: - lower = new_parameter - else: - upper = new_parameter - new_parameter = (upper - lower) * torch.rand( - 1, device=params[self._site_name].device - ) + lower - - return new_parameter, upper - lower diff --git a/sbi/samplers/mcmc/slice_numpy.py b/sbi/samplers/mcmc/slice_numpy.py index ea7733375..605120118 100644 --- a/sbi/samplers/mcmc/slice_numpy.py +++ b/sbi/samplers/mcmc/slice_numpy.py @@ -22,20 +22,20 @@ class MCMCSampler: Superclass for MCMC samplers. """ - def __init__(self, x, lp_f: Callable, thin: Optional[int], verbose: bool = False): + def __init__(self, x, lp_f: Callable, thin: int, verbose: bool = False): """ Args: x: initial state lp_f: Function that returns the log prob. - thin: amount of thinning; if None, no thinning. + thin: Thinning (subsampling) factor, default 1 (no thinning). verbose: Whether to show progress bars (False). """ self.x = np.array(x, dtype=float) self.lp_f = lp_f self.L = lp_f(self.x) - self.thin = 1 if thin is None else thin + self.thin = thin self.n_dims = self.x.size if self.x.ndim == 1 else self.x.shape[1] self.verbose = verbose @@ -61,7 +61,7 @@ def __init__( lp_f, max_width=float("inf"), init_width: Union[float, np.ndarray] = 0.01, - thin=None, + thin=1, tuning: int = 50, verbose: bool = False, ): @@ -222,7 +222,7 @@ def __init__( log_prob_fn: Callable, init_params: np.ndarray, num_chains: int = 1, - thin: Optional[int] = None, + thin: int = 1, tuning: int = 50, verbose: bool = True, init_width: Union[float, np.ndarray] = 0.01, @@ -237,7 +237,7 @@ def __init__( log_prob_fn: Log prob function. init_params: Initial parameters. num_chains: Number of MCMC chains to run in parallel - thin: amount of thinning; if None, no thinning. + thin: Thinning (subsampling) factor, default 1 (no thinning). tuning: Number of tuning steps for brackets. verbose: Show/hide additional info such as progress bars. init_width: Inital width of brackets. @@ -356,7 +356,7 @@ def __init__( log_prob_fn: Callable, init_params: np.ndarray, num_chains: int = 1, - thin: Optional[int] = None, + thin: int = 1, tuning: int = 50, verbose: bool = True, init_width: Union[float, np.ndarray] = 0.01, @@ -369,7 +369,7 @@ def __init__( log_prob_fn: Log prob function. init_params: Initial parameters. num_chains: Number of MCMC chains to run in parallel - thin: amount of thinning; if None, no thinning. + thin: Thinning (subsampling) factor, default 1 (no thinning). tuning: Number of tuning steps for brackets. 
verbose: Show/hide additional info such as progress bars. init_width: Inital width of brackets. diff --git a/tests/mcmc_slice_pyro/LICENSE.md b/tests/mcmc_slice_pyro/LICENSE.md deleted file mode 100644 index d64569567..000000000 --- a/tests/mcmc_slice_pyro/LICENSE.md +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. 
Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. 
Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/tests/mcmc_slice_pyro/__init__.py b/tests/mcmc_slice_pyro/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/mcmc_slice_pyro/common.py b/tests/mcmc_slice_pyro/common.py deleted file mode 100644 index 4d082ed45..000000000 --- a/tests/mcmc_slice_pyro/common.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright (c) 2017-2019 Uber Technologies, Inc. -# SPDX-License-Identifier: Apache-2.0 - -import contextlib -import numbers -import os -import shutil -import tempfile -import warnings -from itertools import product - -import numpy as np -import pytest -import torch -import torch.cuda -from numpy.testing import assert_allclose -from pytest import approx - -""" -Contains test utilities for assertions, approximate comparison (of tensors and other -objects). - -Code has been largely adapted from pytorch/test/common.py Source: -https://github.com/pytorch/pytorch/blob/master/test/common.py -""" - -TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) -RESOURCE_DIR = os.path.join(TESTS_DIR, "resources") -EXAMPLES_DIR = os.path.join(os.path.dirname(TESTS_DIR), "examples") - - -def xfail_param(*args, **kwargs): - return pytest.param(*args, marks=[pytest.mark.xfail(**kwargs)]) - - -def skipif_param(*args, **kwargs): - return pytest.param(*args, marks=[pytest.mark.skipif(**kwargs)]) - - -def suppress_warnings(fn): - def wrapper(*args, **kwargs): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - fn(*args, **kwargs) - - return wrapper - - -# backport of Python 3's context manager -@contextlib.contextmanager -def TemporaryDirectory(): - try: - path = tempfile.mkdtemp() - yield path - finally: - if os.path.exists(path): - shutil.rmtree(path) - - -requires_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda is not available" -) - - -def get_cpu_type(t): - assert t.__module__ == "torch.cuda" - return getattr(torch, t.__class__.__name__) - - -def get_gpu_type(t): - assert t.__module__ == "torch" - return getattr(torch.cuda, t.__name__) - - -@contextlib.contextmanager -def tensors_default_to(host): - """ - Context manager to temporarily use Cpu or Cuda tensors in PyTorch. - - :param str host: Either "cuda" or "cpu". 
- """ - assert host in ("cpu", "cuda"), host - old_module, name = torch.Tensor().type().rsplit(".", 1) - new_module = "torch.cuda" if host == "cuda" else "torch" - torch.set_default_tensor_type("{}.{}".format(new_module, name)) - try: - yield - finally: - torch.set_default_tensor_type("{}.{}".format(old_module, name)) - - -@contextlib.contextmanager -def freeze_rng_state(): - rng_state = torch.get_rng_state() - if torch.cuda.is_available(): - cuda_rng_state = torch.cuda.get_rng_state() - yield - if torch.cuda.is_available(): - torch.cuda.set_rng_state(cuda_rng_state) - torch.set_rng_state(rng_state) - - -@contextlib.contextmanager -def xfail_if_not_implemented(msg="Not implemented"): - try: - yield - except NotImplementedError as e: - pytest.xfail(reason="{}: {}".format(msg, e)) - - -def iter_indices(tensor): - if tensor.dim() == 0: - return range(0) - if tensor.dim() == 1: - return range(tensor.size(0)) - return product(*(range(s) for s in tensor.size())) - - -def is_iterable(obj): - try: - iter(obj) - return True - except BaseException: - return False - - -def assert_tensors_equal(a, b, prec=0.0, msg=""): - assert a.size() == b.size(), msg - if isinstance(prec, numbers.Number) and prec == 0: - assert (a == b).all(), msg - if a.numel() == 0 and b.numel() == 0: - return - b = b.type_as(a) - b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu() - # check that NaNs are in the same locations - nan_mask = a != a - assert torch.equal(nan_mask, b != b), msg - diff = a - b - diff[a == b] = 0 # handle inf - diff[nan_mask] = 0 - if diff.is_signed(): - diff = diff.abs() - if isinstance(prec, torch.Tensor): - assert (diff <= prec).all(), msg - else: - max_err = diff.max().item() - assert max_err <= prec, msg - - -def _safe_coalesce(t): - tc = t.coalesce() - value_map = {} - for idx, val in zip(t._indices().t(), t._values()): - idx_tup = tuple(idx) - if idx_tup in value_map: - value_map[idx_tup] += val - else: - value_map[idx_tup] = val.clone() if torch.is_tensor(val) else val - - new_indices = sorted(list(value_map.keys())) - new_values = [value_map[idx] for idx in new_indices] - if t._values().dim() < 2: - new_values = t._values().new_tensor(new_values) - else: - new_values = torch.stack(new_values) - - new_indices = t._indices().new_tensor(new_indices).t() - tg = t.new(new_indices, new_values, t.size()) - - assert (tc._indices() == tg._indices()).all() - assert (tc._values() == tg._values()).all() - return tg - - -def assert_close(actual, expected, atol=1e-7, rtol=0, msg=""): - if not msg: - msg = "{} vs {}".format(actual, expected) - if isinstance(actual, numbers.Number) and isinstance(expected, numbers.Number): - assert actual == approx(expected, abs=atol, rel=rtol), msg - # Placing this as a second check allows for coercing of numeric types above; - # this can be moved up to harden type checks. 
- elif type(actual) != type(expected): - raise AssertionError( - "cannot compare {} and {}".format(type(actual), type(expected)) - ) - elif torch.is_tensor(actual) and torch.is_tensor(expected): - prec = atol + rtol * abs(expected) if rtol > 0 else atol - assert actual.is_sparse == expected.is_sparse, msg - if actual.is_sparse: - x = _safe_coalesce(actual) - y = _safe_coalesce(expected) - assert_tensors_equal(x._indices(), y._indices(), prec, msg) - assert_tensors_equal(x._values(), y._values(), prec, msg) - else: - assert_tensors_equal(actual, expected, prec, msg) - elif type(actual) == np.ndarray and type(expected) == np.ndarray: - assert_allclose( - actual, expected, atol=atol, rtol=rtol, equal_nan=True, err_msg=msg - ) - elif isinstance(actual, numbers.Number) and isinstance(y, numbers.Number): - assert actual == approx(expected, abs=atol, rel=rtol), msg - elif isinstance(actual, dict): - assert set(actual.keys()) == set(expected.keys()) - for key, x_val in actual.items(): - assert_close( - x_val, - expected[key], - atol=atol, - rtol=rtol, - msg="At key{}: {} vs {}".format(key, x_val, expected[key]), - ) - elif isinstance(actual, str): - assert actual == expected, msg - elif is_iterable(actual) and is_iterable(expected): - assert len(actual) == len(expected), msg - for xi, yi in zip(actual, expected): - assert_close(xi, yi, atol=atol, rtol=rtol, msg="{} vs {}".format(xi, yi)) - else: - assert actual == expected, msg - - -# TODO: Remove `prec` arg, and move usages to assert_close -def assert_equal(actual, expected, prec=1e-5, msg=""): - if prec > 0.0: - return assert_close(actual, expected, atol=prec, msg=msg) - if not msg: - msg = "{} vs {}".format(actual, expected) - if isinstance(actual, numbers.Number) and isinstance(expected, numbers.Number): - assert actual == expected, msg - # Placing this as a second check allows for coercing of numeric types above; - # this can be moved up to harden type checks. 
- elif type(actual) != type(expected): - raise AssertionError( - "cannot compare {} and {}".format(type(actual), type(expected)) - ) - elif torch.is_tensor(actual) and torch.is_tensor(expected): - assert actual.is_sparse == expected.is_sparse, msg - if actual.is_sparse: - x = _safe_coalesce(actual) - y = _safe_coalesce(expected) - assert_tensors_equal(x._indices(), y._indices(), msg=msg) - assert_tensors_equal(x._values(), y._values(), msg=msg) - else: - assert_tensors_equal(actual, expected, msg=msg) - elif type(actual) == np.ndarray and type(actual) == np.ndarray: - assert (actual == expected).all(), msg - elif isinstance(actual, dict): - assert set(actual.keys()) == set(expected.keys()) - for key, x_val in actual.items(): - assert_equal( - x_val, - expected[key], - prec=0.0, - msg="At key{}: {} vs {}".format(key, x_val, expected[key]), - ) - elif isinstance(actual, str): - assert actual == expected, msg - elif is_iterable(actual) and is_iterable(expected): - assert len(actual) == len(expected), msg - for xi, yi in zip(actual, expected): - assert_equal(xi, yi, prec=0.0, msg="{} vs {}".format(xi, yi)) - else: - assert actual == expected, msg - - -def assert_not_equal(x, y, prec=1e-5, msg=""): - try: - assert_equal(x, y, prec) - except AssertionError: - return - raise AssertionError( - "{} \nValues are equal: x={}, y={}, prec={}".format(msg, x, y, prec) - ) diff --git a/tests/mcmc_slice_pyro/conftest.py b/tests/mcmc_slice_pyro/conftest.py deleted file mode 100644 index 73ea5794d..000000000 --- a/tests/mcmc_slice_pyro/conftest.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) 2017-2019 Uber Technologies, Inc. -# SPDX-License-Identifier: Apache-2.0 - -import os -import warnings - -import pyro -import pytest -import torch - -torch.set_default_tensor_type(os.environ.get("PYRO_TENSOR_TYPE", "torch.DoubleTensor")) - - -def pytest_configure(config): - config.addinivalue_line( - "markers", "init(rng_seed): initialize the RNG using the seed provided." - ) - config.addinivalue_line( - "markers", "stage(NAME): mark test to run when testing stage matches NAME." - ) - config.addinivalue_line( - "markers", "disable_validation: disable all validation on this test." - ) - - -def pytest_runtest_setup(item): - pyro.clear_param_store() - if item.get_closest_marker("disable_validation"): - pyro.enable_validation(False) - else: - pyro.enable_validation(True) - test_initialize_marker = item.get_closest_marker("init") - if test_initialize_marker: - rng_seed = test_initialize_marker.kwargs["rng_seed"] - pyro.set_rng_seed(rng_seed) - - -def pytest_addoption(parser): - parser.addoption( - "--stage", - action="append", - metavar="NAME", - default=[], - help="Only run tests matching the stage NAME.", - ) - - parser.addoption( - "--lax", - action="store_true", - default=False, - help="Ignore AssertionError when running tests.", - ) - - -def _get_highest_specificity_marker(stage_marker): - """ - Get the most specific stage marker corresponding to the test. Specificity - of test function marker is the highest, followed by test class marker and - module marker. - - :return: List of most specific stage markers for the test. 
- """ - is_test_collected = False - selected_stages = [] - try: - for marker in stage_marker: - selected_stages = list(marker.args) - is_test_collected = True - break - except TypeError: - selected_stages = list(stage_marker.args) - is_test_collected = True - if not is_test_collected: - raise RuntimeError("stage marker needs at least one stage to be specified.") - return selected_stages - - -def _add_marker(marker, items): - for item in items: - item.add_marker(marker) - - -def pytest_collection_modifyitems(config, items): - test_stages = set(config.getoption("--stage")) - - # add dynamic markers - lax = config.getoption("--lax") - if lax: - _add_marker(pytest.mark.xfail(raises=AssertionError), items) - - # select / deselect tests based on stage criterion - if not test_stages or "all" in test_stages: - return - selected_items = [] - deselected_items = [] - for item in items: - stage_marker = item.get_closest_marker("stage") - if not stage_marker: - selected_items.append(item) - warnings.warn( - f"""No stage associated with the test {item.name}. Will run on - each stage invocation.""", - stacklevel=2, - ) - continue - item_stage_markers = _get_highest_specificity_marker(stage_marker) - if test_stages.isdisjoint(item_stage_markers): - deselected_items.append(item) - else: - selected_items.append(item) - config.hook.pytest_deselected(items=deselected_items) - items[:] = selected_items diff --git a/tests/mcmc_slice_pyro/test_slice.py b/tests/mcmc_slice_pyro/test_slice.py deleted file mode 100644 index 22c1766fc..000000000 --- a/tests/mcmc_slice_pyro/test_slice.py +++ /dev/null @@ -1,515 +0,0 @@ -# Copyright (c) 2017-2019 Uber Technologies, Inc. -# SPDX-License-Identifier: Apache-2.0 - -import logging -import os -from collections import namedtuple - -import pyro -import pytest -import torch -from pyro import distributions as dist -from pyro import optim as optim -from pyro import poutine as poutine -from pyro.contrib.conjugate.infer import ( - BetaBinomialPair, - GammaPoissonPair, - collapse_conjugate, - posterior_replay, -) -from pyro.infer import SVI, TraceEnum_ELBO -from pyro.infer.autoguide import AutoDelta -from pyro.util import ignore_jit_warnings - -from sbi.samplers.mcmc.mcmc import MCMC -from sbi.samplers.mcmc.slice import Slice - -from .common import assert_equal - -# NOTE: Use below imports if this moves upstream -# from tests.common import assert_equal -# from .test_hmc import GaussianChain, rmse - - -class GaussianChain: - def __init__(self, dim, chain_len, num_obs): - self.dim = dim - self.chain_len = chain_len - self.num_obs = num_obs - self.loc_0 = torch.zeros(self.dim) - self.lambda_prec = torch.ones(self.dim) - - def model(self, data): - loc = self.loc_0 - lambda_prec = self.lambda_prec - for i in range(1, self.chain_len + 1): - loc = pyro.sample( - "loc_{}".format(i), dist.Normal(loc=loc, scale=lambda_prec) - ) - pyro.sample("obs", dist.Normal(loc, lambda_prec), obs=data) - - @property - def data(self): - return torch.ones(self.num_obs, self.dim) - - def id_fn(self): - return "dim={}_chain-len={}_num_obs={}".format( - self.dim, self.chain_len, self.num_obs - ) - - -def rmse(t1, t2): - return (t1 - t2).pow(2).mean().sqrt() - - -logger = logging.getLogger(__name__) - - -T = namedtuple( - "TestExample", - [ - "fixture", - "num_samples", - "warmup_steps", - "expected_means", - "expected_precs", - "mean_tol", - "std_tol", - ], -) - -TEST_CASES = [ - T( - GaussianChain(dim=10, chain_len=3, num_obs=1), - num_samples=800, - warmup_steps=200, - expected_means=[0.25, 0.50, 0.75], - 
expected_precs=[1.33, 1, 1.33], - mean_tol=0.09, - std_tol=0.09, - ), - T( - GaussianChain(dim=10, chain_len=4, num_obs=1), - num_samples=1600, - warmup_steps=200, - expected_means=[0.20, 0.40, 0.60, 0.80], - expected_precs=[1.25, 0.83, 0.83, 1.25], - mean_tol=0.07, - std_tol=0.06, - ), - T( - GaussianChain(dim=5, chain_len=2, num_obs=10000), - num_samples=800, - warmup_steps=200, - expected_means=[0.5, 1.0], - expected_precs=[2.0, 10000], - mean_tol=0.05, - std_tol=0.05, - ), - T( - GaussianChain(dim=5, chain_len=9, num_obs=1), - num_samples=1400, - warmup_steps=200, - expected_means=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90], - expected_precs=[1.11, 0.63, 0.48, 0.42, 0.4, 0.42, 0.48, 0.63, 1.11], - mean_tol=0.08, - std_tol=0.08, - ), -] - - -TEST_IDS = [ - t[0].id_fn() if type(t).__name__ == "TestExample" else t[0][0].id_fn() - for t in TEST_CASES -] - - -def mark_jit(*args, **kwargs): - jit_markers = kwargs.pop("marks", []) - jit_markers += [ - pytest.mark.skipif("CI" in os.environ, reason="to reduce running time on CI") - ] - kwargs["marks"] = jit_markers - return pytest.param(*args, **kwargs) - - -def jit_idfn(param): - return "JIT={}".format(param) - - -@pytest.mark.mcmc -@pytest.mark.parametrize( - T._fields, - TEST_CASES, - ids=TEST_IDS, -) -@pytest.mark.skip(reason="Slow test (https://github.com/pytorch/pytorch/issues/12190)") -@pytest.mark.disable_validation() -def test_slice_conjugate_gaussian( - fixture, - num_samples, - warmup_steps, - expected_means, - expected_precs, - mean_tol, - std_tol, -): - pyro.get_param_store().clear() - slice_kernel = Slice(fixture.model) - mcmc = MCMC(slice_kernel, num_samples, warmup_steps, num_chains=3) - mcmc.run(fixture.data) - samples = mcmc.get_samples() - for i in range(1, fixture.chain_len + 1): - param_name = "loc_" + str(i) - latent = samples[param_name] - latent_loc = latent.mean(0) - latent_std = latent.std(0) - expected_mean = torch.ones(fixture.dim) * expected_means[i - 1] - expected_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[i - 1]) - - # Actual vs expected posterior means for the latents - logger.debug("Posterior mean (actual) - {}".format(param_name)) - logger.debug(latent_loc) - logger.debug("Posterior mean (expected) - {}".format(param_name)) - logger.debug(expected_mean) - assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol) - - # Actual vs expected posterior precisions for the latents - logger.debug("Posterior std (actual) - {}".format(param_name)) - logger.debug(latent_std) - logger.debug("Posterior std (expected) - {}".format(param_name)) - logger.debug(expected_std) - assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) -@pytest.mark.parametrize("num_chains", [1, 2]) -def test_logistic_regression(jit, num_chains, mcmc_params_fast: dict): - dim = 3 - data = torch.randn(2000, dim) - true_coefs = torch.arange(1.0, dim + 1.0) - labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() - - def model(data): - coefs_mean = torch.zeros(dim) - coefs = pyro.sample("beta", dist.Normal(coefs_mean, torch.ones(dim))) - y = pyro.sample("y", dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) - return y - - slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True) - mcmc_params_fast["num_chains"] = num_chains - mcmc_params_fast.pop("thin") # thinning is not supported - mcmc = MCMC(slice_kernel, num_samples=500, available_cpu=1, **mcmc_params_fast) - 
mcmc.run(data) - samples = mcmc.get_samples() - assert_equal(rmse(true_coefs, samples["beta"].mean(0)).item(), 0.0, prec=0.1) - - -@pytest.mark.mcmc -def test_beta_bernoulli(mcmc_params_fast: dict): - def model(data): - alpha = torch.tensor([1.1, 1.1]) - beta = torch.tensor([1.1, 1.1]) - p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta)) - pyro.sample("obs", dist.Bernoulli(p_latent), obs=data) - return p_latent - - true_probs = torch.tensor([0.9, 0.1]) - data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1200,)))) - slice_kernel = Slice(model) - mcmc = MCMC( - slice_kernel, num_samples=400, warmup_steps=mcmc_params_fast["warmup_steps"] - ) - mcmc.run(data) - samples = mcmc.get_samples() - assert_equal(samples["p_latent"].mean(0), true_probs, prec=0.02) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) -def test_gamma_normal(jit, mcmc_params_fast: dict): - def model(data): - rate = torch.tensor([1.0, 1.0]) - concentration = torch.tensor([1.0, 1.0]) - p_latent = pyro.sample("p_latent", dist.Gamma(rate, concentration)) - pyro.sample("obs", dist.Normal(3, p_latent), obs=data) - return p_latent - - true_std = torch.tensor([0.5, 2]) - data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True) - mcmc = MCMC( - slice_kernel, num_samples=200, warmup_steps=mcmc_params_fast["warmup_steps"] - ) - mcmc.run(data) - samples = mcmc.get_samples() - assert_equal(samples["p_latent"].mean(0), true_std, prec=0.05) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) -def test_dirichlet_categorical(jit, mcmc_params_fast: dict): - def model(data): - concentration = torch.tensor([1.0, 1.0, 1.0]) - p_latent = pyro.sample("p_latent", dist.Dirichlet(concentration)) - pyro.sample("obs", dist.Categorical(p_latent), obs=data) - return p_latent - - true_probs = torch.tensor([0.1, 0.6, 0.3]) - data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True) - mcmc = MCMC( - slice_kernel, num_samples=200, warmup_steps=mcmc_params_fast["warmup_steps"] - ) - mcmc.run(data) - samples = mcmc.get_samples() - posterior = samples["p_latent"] - assert_equal(posterior.mean(0), true_probs, prec=0.02) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) -@pytest.mark.skip(reason="Slice sampling not implemented for multiple sites yet.") -def test_gamma_beta(jit, mcmc_params_fast: dict): - def model(data): - alpha_prior = pyro.sample("alpha", dist.Gamma(concentration=1.0, rate=1.0)) - beta_prior = pyro.sample("beta", dist.Gamma(concentration=1.0, rate=1.0)) - pyro.sample( - "x", - dist.Beta(concentration1=alpha_prior, concentration0=beta_prior), - obs=data, - ) - - true_alpha = torch.tensor(5.0) - true_beta = torch.tensor(1.0) - data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample( - torch.Size((5000,)) - ) - slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True) - mcmc = MCMC( - slice_kernel, num_samples=500, warmup_steps=mcmc_params_fast["warmup_steps"] - ) - mcmc.run(data) - samples = mcmc.get_samples() - assert_equal(samples["alpha"].mean(0), true_alpha, prec=0.08) - assert_equal(samples["beta"].mean(0), true_beta, prec=0.05) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) -@pytest.mark.skip(reason="Slice sampling not 
implemented for multiple sites yet.") -def test_gaussian_mixture_model(jit, mcmc_params_fast: dict): - K, N = 3, 1000 - - def gmm(data): - mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(K))) - with pyro.plate("num_clusters", K): - cluster_means = pyro.sample( - "cluster_means", dist.Normal(torch.arange(float(K)), 1.0) - ) - with pyro.plate("data", data.shape[0]): - assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) - pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data) - return cluster_means - - true_cluster_means = torch.tensor([1.0, 5.0, 10.0]) - true_mix_proportions = torch.tensor([0.1, 0.3, 0.6]) - cluster_assignments = dist.Categorical(true_mix_proportions).sample( - torch.Size((N,)) - ) - data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() - slice_kernel = Slice( - gmm, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True - ) - mcmc = MCMC( - slice_kernel, num_samples=300, warmup_steps=mcmc_params_fast["warmup_steps"] - ) - mcmc.run(data) - samples = mcmc.get_samples() - assert_equal(samples["phi"].mean(0).sort()[0], true_mix_proportions, prec=0.05) - assert_equal( - samples["cluster_means"].mean(0).sort()[0], true_cluster_means, prec=0.2 - ) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) -@pytest.mark.skip(reason="Slice sampling not implemented for multiple sites yet.") -def test_bernoulli_latent_model(jit, mcmc_params_fast: dict): - @poutine.broadcast - def model(data): - y_prob = pyro.sample("y_prob", dist.Beta(1.0, 1.0)) - with pyro.plate("data", data.shape[0]): - y = pyro.sample("y", dist.Bernoulli(y_prob)) - z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1)) - pyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data) - - N = 2000 - y_prob = torch.tensor(0.3) - y = dist.Bernoulli(y_prob).sample(torch.Size((N,))) - z = dist.Bernoulli(0.65 * y + 0.1).sample() - data = dist.Normal(2.0 * z, 1.0).sample() - slice_kernel = Slice( - model, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True - ) - mcmc = MCMC( - slice_kernel, - num_samples=600, - warmup_steps=mcmc_params_fast["warmup_steps"], - num_chains=1, - ) - mcmc.run(data) - samples = mcmc.get_samples() - assert_equal(samples["y_prob"].mean(0), y_prob, prec=0.05) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("num_steps", [2, 3, 30]) -@pytest.mark.skip(reason="Slice sampling not implemented for multiple sites yet.") -def test_gaussian_hmm(num_steps, mcmc_params_fast: dict): - dim = 4 - - def model(data): - initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(dim))) - with pyro.plate("states", dim): - transition = pyro.sample("transition", dist.Dirichlet(torch.ones(dim, dim))) - emission_loc = pyro.sample( - "emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim)) - ) - emission_scale = pyro.sample( - "emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim)) - ) - x = None - with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): - for t, y in pyro.markov(enumerate(data)): - x = pyro.sample( - "x_{}".format(t), - dist.Categorical(initialize if x is None else transition[x]), - infer={"enumerate": "parallel"}, - ) - pyro.sample( - "y_{}".format(t), - dist.Normal(emission_loc[x], emission_scale[x]), - obs=y, - ) - - def _get_initial_trace(): - guide = AutoDelta( - poutine.block( - model, - expose_fn=lambda msg: not msg["name"].startswith("x") - and not msg["name"].startswith("y"), - ) - ) - elbo = TraceEnum_ELBO(max_plate_nesting=1) - 
svi = SVI(model, guide, optim.Adam({"lr": 0.01}), elbo) - for _ in range(100): - svi.step(data) - return poutine.trace(guide).get_trace(data) - - def _generate_data(): - transition_probs = torch.rand(dim, dim) - emissions_loc = torch.arange(dim, dtype=torch.Tensor().dtype) - emissions_scale = 1.0 - state = torch.tensor(1) - obs = [dist.Normal(emissions_loc[state], emissions_scale).sample()] - for _ in range(num_steps): - state = dist.Categorical(transition_probs[state]).sample() - obs.append(dist.Normal(emissions_loc[state], emissions_scale).sample()) - return torch.stack(obs) - - data = _generate_data() - slice_kernel = Slice( - model, max_plate_nesting=1, jit_compile=True, ignore_jit_warnings=True - ) - if num_steps == 30: - slice_kernel.initial_trace = _get_initial_trace() - mcmc = MCMC( - slice_kernel, num_samples=5, warmup_steps=mcmc_params_fast["warmup_steps"] - ) - mcmc.run(data) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("hyperpriors", [False, True]) -@pytest.mark.skip(reason="Slice sampling not implemented for multiple sites yet.") -def test_beta_binomial(hyperpriors, mcmc_params_fast: dict): - def model(data): - with pyro.plate("plate_0", data.shape[-1]): - alpha = ( - pyro.sample("alpha", dist.HalfCauchy(1.0)) - if hyperpriors - else torch.tensor([1.0, 1.0]) - ) - beta = ( - pyro.sample("beta", dist.HalfCauchy(1.0)) - if hyperpriors - else torch.tensor([1.0, 1.0]) - ) - beta_binom = BetaBinomialPair() - with pyro.plate("plate_1", data.shape[-2]): - probs = pyro.sample("probs", beta_binom.latent(alpha, beta)) - with pyro.plate("data", data.shape[0]): - pyro.sample( - "binomial", - beta_binom.conditional(probs=probs, total_count=total_count), - obs=data, - ) - - true_probs = torch.tensor([[0.7, 0.4], [0.6, 0.4]]) - total_count = torch.tensor([[1000, 600], [400, 800]]) - num_samples = 80 - data = dist.Binomial(total_count=total_count, probs=true_probs).sample( - sample_shape=(torch.Size((10,))) - ) - hmc_kernel = Slice( - collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True - ) - mcmc = MCMC( - hmc_kernel, - num_samples=num_samples, - warmup_steps=mcmc_params_fast["warmup_steps"], - ) - mcmc.run(data) - samples = mcmc.get_samples() - posterior = posterior_replay(model, samples, data, num_samples=num_samples) - assert_equal(posterior["probs"].mean(0), true_probs, prec=0.05) - - -@pytest.mark.mcmc -@pytest.mark.parametrize("hyperpriors", [False, True]) -@pytest.mark.skip(reason="Slice sampling not implemented for multiple sites yet.") -def test_gamma_poisson(hyperpriors, mcmc_params_fast: dict): - def model(data): - with pyro.plate("latent_dim", data.shape[1]): - alpha = ( - pyro.sample("alpha", dist.HalfCauchy(1.0)) - if hyperpriors - else torch.tensor([1.0, 1.0]) - ) - beta = ( - pyro.sample("beta", dist.HalfCauchy(1.0)) - if hyperpriors - else torch.tensor([1.0, 1.0]) - ) - gamma_poisson = GammaPoissonPair() - rate = pyro.sample("rate", gamma_poisson.latent(alpha, beta)) - with pyro.plate("data", data.shape[0]): - pyro.sample("obs", gamma_poisson.conditional(rate), obs=data) - - true_rate = torch.tensor([3.0, 10.0]) - num_samples = 100 - data = dist.Poisson(rate=true_rate).sample(sample_shape=(torch.Size((100,)))) - slice_kernel = Slice( - collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True - ) - mcmc = MCMC( - slice_kernel, - num_samples=num_samples, - warmup_steps=mcmc_params_fast["warmup_steps"], - ) - mcmc.run(data) - samples = mcmc.get_samples() - posterior = posterior_replay(model, samples, data, num_samples=num_samples) - 
assert_equal(posterior["rate"].mean(0), true_rate, prec=0.3) diff --git a/tests/mcmc_test.py b/tests/mcmc_test.py index d2ebae928..43cc0015b 100644 --- a/tests/mcmc_test.py +++ b/tests/mcmc_test.py @@ -3,7 +3,6 @@ from __future__ import annotations -import arviz as az import numpy as np import pytest import torch @@ -17,6 +16,7 @@ simulate_for_sbi, ) from sbi.neural_nets import likelihood_nn +from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler from sbi.samplers.mcmc.slice_numpy import ( SliceSampler, SliceSamplerSerial, @@ -26,24 +26,20 @@ diagonal_linear_gaussian, true_posterior_linear_gaussian_mvn_prior, ) -from sbi.utils.user_input_checks import ( - process_prior, -) +from sbi.utils.user_input_checks import process_prior from tests.test_utils import check_c2st @pytest.mark.mcmc @pytest.mark.parametrize("num_dim", (1, 2)) -def test_c2st_slice_np_on_Gaussian(num_dim: int): +def test_c2st_slice_np_on_Gaussian( + num_dim: int, warmup: int = 100, num_samples: int = 500 +): """Test MCMC on Gaussian, comparing to ground truth target via c2st. Args: num_dim: parameter dimension of the gaussian model - """ - warmup = 100 - num_samples = 500 - likelihood_shift = -1.0 * ones(num_dim) likelihood_cov = 0.3 * eye(num_dim) prior_mean = zeros(num_dim) @@ -84,7 +80,6 @@ def test_c2st_slice_np_vectorized_parallelized_on_Gaussian( Args: num_dim: parameter dimension of the gaussian model - """ num_samples = 500 warmup = mcmc_params_accurate["warmup_steps"] @@ -135,13 +130,63 @@ def lp_f(x): check_c2st(samples, target_samples, alg=alg) +@pytest.mark.mcmc +@pytest.mark.slow +@pytest.mark.parametrize("step", ("nuts", "hmc", "slice")) +@pytest.mark.parametrize("num_chains", (1, 3)) +def test_c2st_pymc_sampler_on_Gaussian( + step: str, + num_chains: int, + num_dim: int = 2, + num_samples: int = 1000, + warmup: int = 100, +): + """Test PyMC on Gaussian, comparing to ground truth target via c2st.""" + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + x_o = zeros((1, num_dim)) + target_distribution = true_posterior_linear_gaussian_mvn_prior( + x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + target_samples = target_distribution.sample((num_samples,)) + + def lp_f(x, track_gradients=True): + with torch.set_grad_enabled(track_gradients): + return target_distribution.log_prob(x) + + sampler = PyMCSampler( + potential_fn=lp_f, + initvals=np.zeros((num_chains, num_dim)).astype(np.float32), + step=step, + draws=(int(num_samples / num_chains)), # PyMC does not use thinning + tune=warmup, + chains=num_chains, + ) + samples = sampler.run() + assert samples.shape == ( + num_chains, + int(num_samples / num_chains), + num_dim, + ) + samples = samples.reshape(-1, num_dim) + + samples = torch.as_tensor(samples, dtype=torch.float32) + alg = f"pymc_{step}" + + check_c2st(samples, target_samples, alg=alg) + + @pytest.mark.mcmc @pytest.mark.parametrize( "method", ( - "nuts", - "hmc", - "slice", + "nuts_pyro", + "hmc_pyro", + "nuts_pymc", + "hmc_pymc", + "slice_pymc", "slice_np", "slice_np_vectorized", ), @@ -185,4 +230,19 @@ def test_getting_inference_diagnostics(method, mcmc_params_fast: dict): ) idata = posterior.get_arviz_inference_data() - az.plot_trace(idata) + assert hasattr(idata, "posterior"), ( + f"`MCMCPosterior.get_arviz_inference_data()` for method {method} " + f"returned invalid InferenceData. 
Must contain key 'posterior', " + f"but found only {list(idata.keys())}" + ) + samples = getattr(idata.posterior, posterior.param_name).data + samples = samples.reshape(-1, samples.shape[-1])[:: mcmc_params_fast["thin"]][ + :num_samples + ] + assert samples.shape == ( + num_samples, + num_dim, + ), ( + f"MCMC samples for method {method} have incorrect shape (n_samples, n_dims). " + f"Expected {(num_samples, num_dim)}, got {samples.shape}" + ) diff --git a/tests/posterior_sampler_test.py b/tests/posterior_sampler_test.py index b46a5b317..6b0b59af8 100644 --- a/tests/posterior_sampler_test.py +++ b/tests/posterior_sampler_test.py @@ -5,41 +5,46 @@ import pytest from pyro.infer.mcmc import MCMC -from torch import eye, zeros +from torch import Tensor, eye, zeros from torch.distributions import MultivariateNormal -from sbi import utils as utils from sbi.inference import ( SNL, MCMCPosterior, likelihood_estimator_based_potential, simulate_for_sbi, ) -from sbi.samplers.mcmc import SliceSamplerSerial, SliceSamplerVectorized +from sbi.samplers.mcmc import PyMCSampler, SliceSamplerSerial, SliceSamplerVectorized from sbi.simulators.linear_gaussian import diagonal_linear_gaussian @pytest.mark.mcmc @pytest.mark.parametrize( "sampling_method", - ("slice_np", "slice_np_vectorized", "slice", "nuts", "hmc"), + ( + "slice_np", + "slice_np_vectorized", + "nuts_pyro", + "hmc_pyro", + "nuts_pymc", + "hmc_pymc", + "slice_pymc", + ), ) +@pytest.mark.parametrize("num_chains", (1, pytest.param(3, marks=pytest.mark.slow))) def test_api_posterior_sampler_set( - sampling_method: str, set_seed, mcmc_params_fast: dict + sampling_method: str, + num_chains: int, + set_seed, + mcmc_params_fast: dict, + num_dim: int = 2, + num_samples: int = 42, + num_trials: int = 2, + num_simulations: int = 10, ): - """Runs SNL and checks that posterior_sampler is correctly set. - - Args: - mcmc_method: which mcmc method to use for sampling - set_seed: fixture for manual seeding - """ - - num_dim = 2 - num_samples = 10 - num_trials = 2 - num_simulations = 10 + """Runs SNL and checks that posterior_sampler is correctly set.""" x_o = zeros((num_trials, num_dim)) - # Test for multiple chains is cheap when vectorized. 
+ mcmc_params_fast["num_chains"] = num_chains prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) simulator = diagonal_linear_gaussian @@ -58,17 +63,18 @@ def test_api_posterior_sampler_set( ) assert posterior.posterior_sampler is None - posterior.sample( - sample_shape=(num_samples, mcmc_params_fast["num_chains"]), + samples = posterior.sample( + sample_shape=(num_samples, num_chains), x=x_o, - mcmc_parameters={ - "init_strategy": "prior", - **mcmc_params_fast, - }, + mcmc_parameters={"init_strategy": "prior", **mcmc_params_fast}, ) + assert isinstance(samples, Tensor) + assert samples.shape == (num_samples, num_chains, num_dim) - if sampling_method in ["slice", "hmc", "nuts"]: + if "pyro" in sampling_method: assert type(posterior.posterior_sampler) is MCMC + elif "pymc" in sampling_method: + assert type(posterior.posterior_sampler) is PyMCSampler elif sampling_method == "slice_np": assert type(posterior.posterior_sampler) is SliceSamplerSerial else: # sampling_method == "slice_np_vectorized" diff --git a/tutorials/00_getting_started_flexible.ipynb b/tutorials/00_getting_started_flexible.ipynb index 552bb71fc..b8094b100 100644 --- a/tutorials/00_getting_started_flexible.ipynb +++ b/tutorials/00_getting_started_flexible.ipynb @@ -136,7 +136,7 @@ "metadata": {}, "outputs": [], "source": [ - "inference = SNPE(prior=prior) " + "inference = SNPE(prior=prior)" ] }, { @@ -266,7 +266,7 @@ "outputs": [], "source": [ "theta_true = prior.sample((1,))\n", - "# generate our observation \n", + "# generate our observation\n", "x_obs = simulator(theta_true)" ] }, diff --git a/tutorials/01_gaussian_amortized.ipynb b/tutorials/01_gaussian_amortized.ipynb index 67b9e7833..69f798783 100644 --- a/tutorials/01_gaussian_amortized.ipynb +++ b/tutorials/01_gaussian_amortized.ipynb @@ -183,7 +183,7 @@ "# plot posterior samples\n", "_ = analysis.pairplot(\n", " posterior_samples_1, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n", - " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n", " points=theta_1 # add ground truth thetas\n", ")" ] @@ -238,7 +238,7 @@ "# plot posterior samples\n", "_ = analysis.pairplot(\n", " posterior_samples_2, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n", - " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n", " points=theta_2 # add ground truth thetas\n", ")" ]
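Minimal usage sketch (not part of the patch) of the `PyMCSampler` wrapper exercised in `tests/mcmc_test.py` above. Only the constructor arguments and the `run()` call that appear in that test are relied on here; the 2D Gaussian target and the concrete chain/draw/tune counts are illustrative placeholders.

import numpy as np
import torch
from torch.distributions import MultivariateNormal

from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler

# Illustrative target: a 2D standard normal whose log-prob serves as the potential.
num_chains, num_dim = 2, 2
target = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))

def potential_fn(theta, track_gradients=True):
    # Gradient tracking is required by the gradient-based steps ("nuts", "hmc").
    with torch.set_grad_enabled(track_gradients):
        return target.log_prob(theta)

sampler = PyMCSampler(
    potential_fn=potential_fn,
    initvals=np.zeros((num_chains, num_dim), dtype=np.float32),
    step="nuts",   # or "hmc" / "slice", as parametrized in the test
    draws=500,     # samples per chain; PyMC does not use thinning
    tune=100,      # warmup steps
    chains=num_chains,
)
samples = sampler.run()  # returned array has shape (chains, draws, dim)
samples = torch.as_tensor(samples.reshape(-1, num_dim), dtype=torch.float32)

Within the `sbi` API itself, the same sampler is reached by passing `method="nuts_pymc"`, `"hmc_pymc"`, or `"slice_pymc"` to the posterior, which the updated `posterior_sampler_test.py` verifies via `type(posterior.posterior_sampler) is PyMCSampler`.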