diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index f382968cd..8066d3dd1 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -1,7 +1,8 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see -from typing import Callable, Optional, Tuple +import warnings +from typing import Callable, List, Optional, Tuple import torch from torch import Tensor @@ -115,6 +116,54 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: ) return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore + def condition_on_theta( + self, local_theta: Tensor, dims_global_theta: List[int] + ) -> Callable: + r"""Returns a potential function conditioned on a subset of theta dimensions. + + The goal of this function is to divide the original `theta` into a + `global_theta` we do inference over, and a `local_theta` we condition on (in + addition to conditioning on `x_o`). Thus, the returned potential function will + calculate $\prod_{i=1}^{N}p(x_i | local_theta_i, global_theta)$, where `x_i` + and `local_theta_i` are fixed and `global_theta` varies at inference time. + + Args: + local_theta: The values of the local parameters to condition on. + dims_global_theta: The indices of the columns in `theta` that will be + sampled, i.e., that are *not* conditioned. For example, if original theta + has shape `(batch_dim, 3)`, and `dims_global_theta=[0, 1]`, then the + potential will set `theta[:, 2] = local_theta` at inference time. + + Returns: + A potential function conditioned on the `local_theta`. + """ + + assert self.x_is_iid, "Conditioning is only supported for iid data." 
+ + def conditioned_potential( + theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True + ) -> Tensor: + assert ( + len(dims_global_theta) == theta.shape[1] + ), "dims_global_theta must match the number of parameters to sample." + global_theta = theta[:, dims_global_theta] + x_o = x_o if x_o is not None else self.x_o + # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape) + if x_o.dim() < 3: + x_o = reshape_to_sample_batch_event( + x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid + ) + + return _log_likelihood_over_iid_trials_and_local_theta( + x=x_o, + global_theta=global_theta, + local_theta=local_theta, + estimator=self.likelihood_estimator, + track_gradients=track_gradients, + ) + + return conditioned_potential + def _log_likelihoods_over_trials( x: Tensor, @@ -172,6 +221,77 @@ def _log_likelihoods_over_trials( return log_likelihood_trial_sum +def _log_likelihood_over_iid_trials_and_local_theta( + x: Tensor, + global_theta: Tensor, + local_theta: Tensor, + estimator: ConditionalDensityEstimator, + track_gradients: bool = False, +) -> Tensor: + """Returns $\\sum_{i=1}^N \\log p(x_i|\\theta, local_theta_i)$. + + `x` is a batch of iid data, and `local_theta` is a matching batch of condition + values that were part of `theta` but are treated as local iid variables at inference + time. + + This function is different from `_log_likelihoods_over_trials` in that it moves the + iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when + the likelihood estimator is conditioned on a batch of conditions that are iid with + the batch of `x`. It avoids the evaluation of the likelihood for every combination + of `x` and `local_theta`. + + Args: + x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where sample_dim + holds the i.i.d. trials and batch_dim holds a batch of xs, e.g., non-iid + observations. + global_theta: Batch of parameters `(theta_batch_dim, + num_parameters)`. 
+ local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`, must + match x's `sample_dim`. + estimator: DensityEstimator. + track_gradients: Whether to track gradients. + + Returns: + log_likelihood: log likelihood for each x in x_batch_dim, for each theta in + theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim, + theta_batch_dim)`. + """ + assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)." + assert ( + local_theta.dim() == 2 + ), "condition must have shape (sample_dim, num_conditions)." + assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)." + num_trials, num_xs = x.shape[:2] + num_thetas = global_theta.shape[0] + assert ( + local_theta.shape[0] == num_trials + ), "Condition batch size must match the number of iid trials in x." + + # move the iid batch dimension onto the batch dimension of theta and repeat it there + x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1) + + # construct theta and condition to cover all trial-theta combinations + theta_with_condition = torch.cat( + [ + global_theta.repeat(num_trials, 1), # repeat ABAB + local_theta.repeat_interleave(num_thetas, dim=0), # repeat AABB + ], + dim=-1, + ) + + with torch.set_grad_enabled(track_gradients): + # Calculate likelihood in one batch. Returns (1, num_trials * num_theta) + log_likelihood_trial_batch = estimator.log_prob( + x_repeated, condition=theta_with_condition + ) + # Reshape to (x-trials x parameters), sum over trial-log likelihoods. + log_likelihood_trial_sum = log_likelihood_trial_batch.reshape( + num_xs, num_trials, num_thetas + ).sum(1) + + return log_likelihood_trial_sum + + def mixed_likelihood_estimator_based_potential( likelihood_estimator: MixedDensityEstimator, prior: Distribution, @@ -192,6 +312,13 @@ def mixed_likelihood_estimator_based_potential( to unconstrained space. """ + warnings.warn( + "This function is deprecated and will be removed in a future release. 
Use " + "`likelihood_estimator_based_potential` instead.", + DeprecationWarning, + stacklevel=2, + ) + device = str(next(likelihood_estimator.discrete_net.parameters()).device) potential_fn = MixedLikelihoodBasedPotential( @@ -212,6 +339,13 @@ def __init__( ): super().__init__(likelihood_estimator, prior, x_o, device) + warnings.warn( + "This function is deprecated and will be removed in a future release. Use " + "`LikelihoodBasedPotential` instead.", + DeprecationWarning, + stacklevel=2, + ) + def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: prior_log_prob = self.prior.log_prob(theta) # type: ignore @@ -231,7 +365,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: with torch.set_grad_enabled(track_gradients): # Call the specific log prob method of the mixed likelihood estimator as # this optimizes the evaluation of the discrete data part. - # TODO log_prob_iid log_likelihood_trial_batch = self.likelihood_estimator.log_prob( input=x, condition=theta.to(self.device), diff --git a/sbi/inference/trainers/nle/mnle.py b/sbi/inference/trainers/nle/mnle.py index d01ce1e91..83622eaea 100644 --- a/sbi/inference/trainers/nle/mnle.py +++ b/sbi/inference/trainers/nle/mnle.py @@ -7,7 +7,7 @@ from torch.distributions import Distribution from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior -from sbi.inference.potentials import mixed_likelihood_estimator_based_potential +from sbi.inference.potentials import likelihood_estimator_based_potential from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator from sbi.neural_nets.estimators import MixedDensityEstimator from sbi.sbi_types import TensorboardSummaryWriter, TorchModule @@ -155,9 +155,7 @@ def build_posterior( ( potential_fn, theta_transform, - ) = mixed_likelihood_estimator_based_potential( - likelihood_estimator=likelihood_estimator, prior=prior, x_o=None - ) + ) = likelihood_estimator_based_potential(likelihood_estimator, prior, 
x_o=None) if sample_with == "mcmc": self._posterior = MCMCPosterior( diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py index d6c73b7c9..829f5e1df 100644 --- a/sbi/utils/conditional_density_utils.py +++ b/sbi/utils/conditional_density_utils.py @@ -293,7 +293,7 @@ def __init__( masked outside of prior. """ condition = torch.atleast_2d(condition) - if condition.shape[0] != 1: + if condition.shape[0] > 1: raise ValueError("Condition with batch size > 1 not supported.") self.potential_fn = potential_fn diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index fcb5953d9..fc01d4dbd 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) - if num_unique_z < num_unique * (1 - duplicate_tolerance): warnings.warn( - "Z-scoring these simulation outputs resulted in {num_unique_z} unique " - "datapoints. Before z-scoring, it had been {num_unique}. This can " + f"Z-scoring these simulation outputs resulted in {num_unique_z} unique " + f"datapoints. Before z-scoring, it had been {num_unique}. This can " "occur due to numerical inaccuracies when the data covers a large " "range of values. Consider either setting `z_score_x=False` (but " "beware that this can be problematic for training the NN) or exclude " diff --git a/tests/mnle_test.py b/tests/mnle_test.py index a95a2a6ac..099876a3e 100644 --- a/tests/mnle_test.py +++ b/tests/mnle_test.py @@ -1,29 +1,32 @@ # This file is part of sbi, a toolkit for simulation-based inference. 
sbi is licensed # under the Apache License Version 2.0, see +from typing import Union + import pytest import torch from pyro.distributions import InverseGamma -from torch.distributions import Beta, Binomial, Categorical, Gamma +from torch import Tensor +from torch.distributions import Beta, Binomial, Distribution, Gamma from sbi.inference import MNLE, MCMCPosterior from sbi.inference.posteriors.rejection_posterior import RejectionPosterior from sbi.inference.posteriors.vi_posterior import VIPosterior from sbi.inference.potentials.base_potential import BasePotential from sbi.inference.potentials.likelihood_based_potential import ( - MixedLikelihoodBasedPotential, + _log_likelihood_over_iid_trials_and_local_theta, + likelihood_estimator_based_potential, ) from sbi.neural_nets import likelihood_nn from sbi.neural_nets.embedding_nets import FCEmbedding from sbi.utils import BoxUniform, mcmc_transform -from sbi.utils.conditional_density_utils import ConditionedPotential from sbi.utils.torchutils import atleast_2d, process_device from sbi.utils.user_input_checks_utils import MultipleIndependent from tests.test_utils import check_c2st # toy simulator for mixed data -def mixed_simulator(theta, stimulus_condition=2.0): +def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.0): """Simulator for mixed data.""" # Extract parameters beta, ps = theta[:, :1], theta[:, 1:] @@ -37,6 +40,15 @@ def mixed_simulator(theta, stimulus_condition=2.0): return torch.cat((rts, choices), dim=1) +def wrapped_simulator( + theta_and_condition: Tensor, last_idx_parameters: int = 2 +) -> Tensor: + # simulate with experiment conditions + theta = theta_and_condition[:, :last_idx_parameters] + condition = theta_and_condition[:, last_idx_parameters:] + return mixed_simulator(theta, condition) + + @pytest.mark.mcmc @pytest.mark.gpu @pytest.mark.parametrize("device", ("cpu", "gpu")) @@ -190,11 +202,28 @@ def test_mnle_accuracy_with_different_samplers_and_trials( class 
BinomialGammaPotential(BasePotential): - def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"): + """Binomial-Gamma potential for mixed data.""" + + def __init__( + self, + prior: Distribution, + x_o: Tensor, + concentration_scaling: Union[Tensor, float] = 1.0, + device="cpu", + ): super().__init__(prior, x_o, device) + + # concentration_scaling needs to be a float or match the batch size + if isinstance(concentration_scaling, Tensor): + num_trials = x_o.shape[0] + assert concentration_scaling.shape[0] == num_trials + + # Reshape to match convention (batch_size, num_trials, *event_shape) + concentration_scaling = concentration_scaling.reshape(1, num_trials, -1) + self.concentration_scaling = concentration_scaling - def __call__(self, theta, track_gradients: bool = True): + def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: theta = atleast_2d(theta) with torch.set_grad_enabled(track_gradients): @@ -202,11 +231,12 @@ def __call__(self, theta, track_gradients: bool = True): return iid_ll + self.prior.log_prob(theta) - def iid_likelihood(self, theta): + def iid_likelihood(self, theta: Tensor) -> Tensor: batch_size = theta.shape[0] num_trials = self.x_o.shape[0] theta = theta.reshape(batch_size, 1, -1) beta, rho = theta[:, :, :1], theta[:, :, 1:] + # vectorized logprob_choices = Binomial(probs=rho).log_prob( self.x_o[:, 1:].reshape(1, num_trials, -1) @@ -233,43 +263,44 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict): categorical parameter is set to a fixed value (conditioned posterior), and the accuracy of the conditioned posterior is tested against the true posterior. 
""" - num_simulations = 6000 - num_samples = 500 - - def sim_wrapper(theta): - # simulate with experiment conditions - return mixed_simulator(theta[:, :2], theta[:, 2:] + 1) + num_simulations = 10000 + num_samples = 1000 proposal = MultipleIndependent( [ Gamma(torch.tensor([1.0]), torch.tensor([0.5])), Beta(torch.tensor([2.0]), torch.tensor([2.0])), - Categorical(probs=torch.ones(1, 3)), + BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])), ], validate_args=False, ) theta = proposal.sample((num_simulations,)) - x = sim_wrapper(theta) + x = wrapped_simulator(theta) assert x.shape == (num_simulations, 2) num_trials = 10 - theta_o = proposal.sample((1,)) - theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator. - x_o = sim_wrapper(theta_o.repeat(num_trials, 1)) + theta_and_condition = proposal.sample((num_trials,)) + # use only a single parameter (iid trials) + theta_o = theta_and_condition[:1, :2].repeat(num_trials, 1) + # but different conditions + condition_o = theta_and_condition[:, 2:] + theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1) + + x_o = wrapped_simulator(theta_and_conditions_o) mcmc_kwargs = dict( method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate ) # MNLE - trainer = MNLE(proposal) - estimator = trainer.append_simulations(theta, x).train(training_batch_size=1000) - - potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o) + estimator_fun = likelihood_nn(model="mnle", z_score_x=None) + trainer = MNLE(proposal, estimator_fun) + estimator = trainer.append_simulations(theta, x).train() - conditioned_potential_fn = ConditionedPotential( - potential_fn, condition=theta_o, dims_to_sample=[0, 1] + potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o) + conditioned_potential_fn = potential_fn.condition_on_theta( + condition_o, dims_global_theta=[0, 1] ) # True posterior samples @@ -283,10 +314,7 @@ def sim_wrapper(theta): prior_transform = mcmc_transform(prior) 
true_posterior_samples = MCMCPosterior( BinomialGammaPotential( - prior, - atleast_2d(x_o), - concentration_scaling=float(theta_o[0, 2]) - + 1.0, # add one because the sim_wrapper adds one (see above) + prior, atleast_2d(x_o), concentration_scaling=condition_o ), theta_transform=prior_transform, proposal=prior, @@ -303,5 +331,86 @@ def sim_wrapper(theta): check_c2st( cond_samples, true_posterior_samples, - alg=f"MNLE trained with {num_simulations}", + alg=f"MNLE trained with {num_simulations} simulations", + ) + + +@pytest.mark.parametrize("num_thetas", [1, 10]) +@pytest.mark.parametrize("num_trials", [1, 5]) +@pytest.mark.parametrize("num_xs", [1, 3]) +@pytest.mark.parametrize( + "num_conditions", + [ + 1, + pytest.param( + 2, + marks=pytest.mark.xfail( + reason="Batched theta_condition is not " "supported" + ), + ), + ], +) +def test_log_likelihood_over_local_iid_theta( + num_thetas, num_trials, num_xs, num_conditions +): + """Test log likelihood over iid conditions using MNLE. + + Args: + num_thetas: batch of theta to condition on. + num_trials: number of i.i.d. trials in x + num_xs: batch of x, e.g., different subjects in a study. + num_conditions: number of batches of conditions, e.g., different conditions + for each x (not implemented yet). 
+ """ + + # train mnle on mixed data + trainer = MNLE( + density_estimator=likelihood_nn(model="mnle", z_score_x=None), ) + proposal = MultipleIndependent( + [ + Gamma(torch.tensor([1.0]), torch.tensor([0.5])), + Beta(torch.tensor([2.0]), torch.tensor([2.0])), + BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])), + ], + validate_args=False, + ) + + num_simulations = 100 + theta = proposal.sample((num_simulations,)) + x = wrapped_simulator(theta) + estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1) + + # condition on multiple conditions + theta_o = proposal.sample((num_xs,))[:, :2] + + x_o = torch.zeros(num_trials, num_xs, 2) + condition_o = proposal.sample(( + num_conditions, + num_trials, + ))[:, 2:].reshape(num_trials, 1) + for i in range(num_xs): + # simulate with same iid theta but different conditions + x_o[:, i, :] = mixed_simulator(theta_o[i].repeat(num_trials, 1), condition_o) + + # batched conditioning + theta = proposal.sample((num_thetas,))[:, :2] + # x_o has shape (iid, batch, *event) + # condition_o has shape (iid, num_conditions) + ll_batched = _log_likelihood_over_iid_trials_and_local_theta( + x_o, theta, condition_o, estimator + ) + + # looped conditioning + ll_single = [] + for i in range(num_trials): + theta_and_condition = torch.cat( + (theta, condition_o[i].repeat(num_thetas, 1)), dim=1 + ) + x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1) + ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition)) + ll_single = torch.stack(ll_single).sum(0) # sum over trials + + assert ll_batched.shape == torch.Size([num_xs, num_thetas]) + assert ll_batched.shape == ll_single.shape + assert torch.allclose(ll_batched, ll_single, atol=1e-5) diff --git a/tutorials/Example_01_DecisionMakingModel.ipynb b/tutorials/Example_01_DecisionMakingModel.ipynb index fcfa10ced..eb16182c2 100644 --- a/tutorials/Example_01_DecisionMakingModel.ipynb +++ b/tutorials/Example_01_DecisionMakingModel.ipynb @@ -73,32 +73,23 @@ 
"cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" - ] - } - ], + "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "from pyro.distributions import InverseGamma\n", "from torch import Tensor\n", - "from torch.distributions import Beta, Binomial, Categorical, Gamma\n", + "from torch.distributions import Beta, Binomial, Gamma\n", "\n", "from sbi.analysis import pairplot\n", "from sbi.inference import MNLE, MCMCPosterior\n", - "from sbi.inference.potentials.base_potential import BasePotential\n", - "from sbi.inference.potentials.likelihood_based_potential import (\n", - " MixedLikelihoodBasedPotential,\n", - ")\n", - "from sbi.utils import MultipleIndependent, mcmc_transform\n", - "from sbi.utils.conditional_density_utils import ConditionedPotential\n", + "from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential\n", + "from sbi.neural_nets import likelihood_nn\n", + "from sbi.utils import BoxUniform, MultipleIndependent, mcmc_transform\n", "from sbi.utils.metrics import c2st\n", - "from sbi.utils.torchutils import atleast_2d" + "\n", + "\n", + "from example_01_utils import BinomialGammaPotential" ] }, { @@ -124,44 +115,7 @@ " concentration=concentration_scaling * torch.ones_like(beta), rate=beta\n", " ).sample()\n", "\n", - " return torch.cat((rts, choices), dim=1)\n", - "\n", - "\n", - "# The potential function defines the ground truth likelihood and allows us to\n", - "# obtain reference posterior samples via MCMC.\n", - "class BinomialGammaPotential(BasePotential):\n", - "\n", - " def __init__(self, prior, x_o, concentration_scaling=1.0, device=\"cpu\"):\n", - " super().__init__(prior, x_o, device)\n", - " self.concentration_scaling = concentration_scaling\n", - "\n", - " def __call__(self, theta, track_gradients: bool = 
True):\n", - " theta = atleast_2d(theta)\n", - "\n", - " with torch.set_grad_enabled(track_gradients):\n", - " iid_ll = self.iid_likelihood(theta)\n", - "\n", - " return iid_ll + self.prior.log_prob(theta)\n", - "\n", - " def iid_likelihood(self, theta):\n", - " batch_size = theta.shape[0]\n", - " num_trials = self.x_o.shape[0]\n", - " theta = theta.reshape(batch_size, 1, -1)\n", - " beta, rho = theta[:, :, :1], theta[:, :, 1:]\n", - " # vectorized\n", - " logprob_choices = Binomial(probs=rho).log_prob(\n", - " self.x_o[:, 1:].reshape(1, num_trials, -1)\n", - " )\n", - "\n", - " logprob_rts = InverseGamma(\n", - " concentration=self.concentration_scaling * torch.ones_like(beta),\n", - " rate=beta,\n", - " ).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))\n", - "\n", - " joint_likelihood = (logprob_choices + logprob_rts).squeeze()\n", - "\n", - " assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])\n", - " return joint_likelihood.sum(1)" + " return torch.cat((rts, choices), dim=1)" ] }, { @@ -205,18 +159,10 @@ "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. This might cause the results differ from the last benchmark.\n", - " thin = _process_thin_default(thin)\n" - ] - }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8070275b9eac45d1991d5be41935c145", + "model_id": "92513794bbd148b29b5d60d566338bf6", "version_major": 2, "version_minor": 0 }, @@ -234,6 +180,7 @@ " warmup_steps=50,\n", " method=\"slice_np_vectorized\",\n", " init_strategy=\"proposal\",\n", + " thin=1,\n", ")\n", "\n", "true_posterior = MCMCPosterior(\n", @@ -269,13 +216,13 @@ "name": "stdout", "output_type": "stream", "text": [ - " Neural network successfully converged after 65 epochs." 
+ " Neural network successfully converged after 75 epochs." ] } ], "source": [ "# Training data\n", - "num_simulations = 20000\n", + "num_simulations = 10000\n", "# For training the MNLE emulator we need to define a proposal distribution, the prior is\n", "# a good choice.\n", "proposal = prior\n", @@ -284,7 +231,7 @@ "\n", "# Train MNLE and obtain MCMC-based posterior.\n", "trainer = MNLE()\n", - "estimator = trainer.append_simulations(theta, x).train(training_batch_size=1000)" + "estimator = trainer.append_simulations(theta, x).train()" ] }, { @@ -292,10 +239,18 @@ "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. This might cause the results differ from the last benchmark.\n", + " thin = _process_thin_default(thin)\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1a7792c605404a11a586681fcd3c0a32", + "model_id": "548e67900bd3494481dc61d0f11db250", "version_major": 2, "version_minor": 0 }, @@ -328,7 +283,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -390,7 +345,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fb02120c58a54d029953b4c589f24eca", + "model_id": "21f980fc4f794fe1ab2090ad53a0e323", "version_major": 2, "version_minor": 0 }, @@ -404,7 +359,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1cd3bc58ca8e4a21b1df2812fad8bf45", + "model_id": "418d47f1f4864c089bf68e1a119ebb7d", "version_major": 2, "version_minor": 0 }, @@ -430,7 +385,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -477,7 +432,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "c2st between true and MNLE posterior: 0.593\n" + "c2st between true and MNLE posterior: 0.5155000000000001\n" ] } ], @@ -515,16 +470,26 @@ "metadata": {}, "outputs": [], "source": [ + "# Define a proposal that contains both, priors for the parameters and a discrte\n", + "# prior over experimental conditions.\n", + "proposal = MultipleIndependent(\n", + " [\n", + " Gamma(torch.tensor([1.0]), torch.tensor([0.5])),\n", + " Beta(torch.tensor([2.0]), torch.tensor([2.0])),\n", + " BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),\n", + " ],\n", + " validate_args=False,\n", + ")\n", + "\n", "# define a simulator wrapper in which the experimental condition are contained\n", "# in theta and passed to the simulator.\n", - "def sim_wrapper(theta):\n", + "def sim_wrapper(theta_and_conditions):\n", " # simulate with experiment conditions\n", " return mixed_simulator(\n", " # we assume the first two parameters are beta and rho\n", - " theta=theta[:, :2],\n", + " theta=theta_and_conditions[:, :2],\n", " # we treat the third concentration parameter as an experimental condition\n", - " # add 1 to deal with 0 values from Categorical distribution\n", - " concentration_scaling=theta[:, 2:] + 1,\n", + " concentration_scaling=theta_and_conditions[:, 2:],\n", " )" ] }, @@ -534,17 +499,6 @@ "metadata": {}, "outputs": [], "source": [ - "# Define a proposal that contains both, priors for the parameters and a discrte\n", - "# prior over experimental conditions.\n", - "proposal = MultipleIndependent(\n", - " [\n", - " Gamma(torch.tensor([1.0]), torch.tensor([0.5])),\n", - " Beta(torch.tensor([2.0]), torch.tensor([2.0])),\n", - " Categorical(probs=torch.ones(1, 3)), # 3 discrete conditions\n", - " ],\n", - " validate_args=False,\n", - ")\n", - "\n", "# Simulated data\n", "num_simulations = 10000\n", "num_samples = 1000\n", @@ -554,10 +508,13 @@ "\n", "# simulate observed data and define ground truth 
parameters\n", "num_trials = 10\n", - "theta_o = proposal.sample((1,))\n", - "theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.\n", - "# NOTE: we use the same experimental condition for all trials.\n", - "x_o = sim_wrapper(theta_o.repeat(num_trials, 1))" + "# draw one ground truth parameter\n", + "theta_o = proposal.sample((1,))[:, :2]\n", + "# draw num_trials many different conditions\n", + "conditions = proposal.sample((num_trials,))[:, 2:]\n", + "# Theta is repeated for each trial, conditions are different for each trial.\n", + "theta_and_conditions_o = torch.cat((theta_o.repeat(num_trials, 1), conditions), dim=1)\n", + "x_o = sim_wrapper(theta_and_conditions_o)" ] }, { @@ -566,11 +523,15 @@ "source": [ "#### Obtain ground truth posterior via MCMC\n", "\n", - "We obtain a ground-truth posterior via MCMC by using the PotentialFunctionProvider.\n", + "We obtain a ground-truth posterior via MCMC by using the analytical Binomial-Gamma\n", + "likelihood as before. \n", "\n", - "For that, we first the define the actual prior, i.e., the distribution over the parameter we want to infer (not the proposal).\n", + "For that, we first define the actual prior, i.e., the distribution over the\n", + "parameters we want to infer (not the proposal), dropping the uniform prior over\n", + "experimental conditions.\n", "\n", - "Thus, we leave out the discrete prior over experimental conditions.\n" + "Additionally, we pass the entire batch of i.i.d. data `x_o` and matching batch of i.i.d.\n", + "`conditions`.\n" ] }, { @@ -578,18 +539,10 @@ "execution_count": 14, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. 
This might cause the results differ from the last benchmark.\n", - " thin = _process_thin_default(thin)\n" - ] - }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ad169fdca3da40649e6e1c329460e355", + "model_id": "ee7db79e47674ed3b2574c26b09eb0b2", "version_major": 2, "version_minor": 0 }, @@ -617,8 +570,7 @@ " BinomialGammaPotential(\n", " prior,\n", " x_o,\n", - " concentration_scaling=float(theta_o[0, 2])\n", - " + 1.0, # add one because the sim_wrapper adds one (see above)\n", + " concentration_scaling=conditions,\n", " ),\n", " theta_transform=prior_transform,\n", " proposal=prior,\n", @@ -630,7 +582,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Train MNLE including experimental conditions\n" + "### Train MNLE including experimental conditions\n", + "\n", + "Next, we use the combined parameters and conditions (`theta`) and the corresponding\n", + "simulated data to train `MNLE`.\n" ] }, { @@ -642,6 +597,8 @@ "name": "stderr", "output_type": "stream", "text": [ + "/Users/janteusen/qode/sbi/sbi/inference/trainers/base.py:271: UserWarning: Z-scoring these simulation outputs resulted in 4 unique datapoints. Before z-scoring, it had been 19872. This can occur due to numerical inaccuracies when the data covers a large range of values. Consider either setting `z_score_x=False` (but beware that this can be problematic for training the NN) or exclude outliers from your dataset. Note: if you have already set `z_score_x=False`, this warning will still be displayed, but you can ignore it.\n", + " warn_if_zscoring_changes_data(x)\n", "/Users/janteusen/qode/sbi/sbi/neural_nets/factory.py:205: UserWarning: The mixed neural likelihood estimator assumes that x contains continuous data in the first n-1 columns (e.g., reaction times) and categorical data in the last column (e.g., corresponding choices). 
If this is not the case for the passed `x` do not use this function.\n", " return model_builders[model](batch_x=batch_x, batch_y=batch_theta, **kwargs)\n" ] @@ -650,12 +607,13 @@ "name": "stdout", "output_type": "stream", "text": [ - " Neural network successfully converged after 60 epochs." + " Neural network successfully converged after 75 epochs." ] } ], "source": [ - "trainer = MNLE(proposal)\n", + "estimator_builder = likelihood_nn(model=\"mnle\", z_score_x=None) # we don't want to z-score the binary data.\n", + "trainer = MNLE(proposal, estimator_builder)\n", "estimator = trainer.append_simulations(theta, x).train()" ] }, @@ -681,28 +639,147 @@ "outputs": [ { "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4f887f2ba37a4782964e838895cfc39e", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "torch.Size([1, 3])" + "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Finally, we can compare the ground truth conditional posterior with the\n", + "# MNLE-conditional posterior.\n", + "fig, ax = pairplot(\n", + " [\n", + " prior.sample((1000,)),\n", + " true_posterior_samples,\n", + " conditional_samples,\n", + " ],\n", + " points=theta_o,\n", + " diag=\"kde\",\n", + " upper=\"contour\",\n", + " diag_kwargs=dict(bins=100),\n", + " upper_kwargs=dict(levels=[0.95]),\n", + " fig_kwargs=dict(\n", + " points_offdiag=dict(marker=\"*\", markersize=10),\n", + " points_colors=[\"k\"],\n", + "\n", + " ),\n", + " labels=[r\"$\\beta$\", r\"$\\rho$\"],\n", + " figsize=(6, 6),\n", + ")\n", + "\n", + "plt.sca(ax[1, 1])\n", + "plt.legend(\n", + " [\"Prior\", \"Reference\", \"MNLE\", r\"$\\theta_o$\"],\n", + " frameon=False,\n", + " fontsize=12,\n", + ");\n", + "print(\"c2st between true and MNLE posterior:\", c2st(true_posterior_samples, conditional_samples).item())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "They 
match accurately, showing that we can indeed post-hoc condition the trained MNLE likelihood on different experimental conditions.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference with multiple subjects, trials, and conditions\n", + "\n", + "Note that we can also do inference for multiple `x_os` (e.g., subjects) with varying\n", + "numbers of trials and experimental conditions - all without retraining the MNLE.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "54115a1a0f534028b377fa5aa4661dc4", + "model_id": "ed79d139f3804547ab14ea8dcdea856e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00" ] @@ -754,15 +845,11 @@ } ], "source": [ - "# Finally, we can compare the ground truth conditional posterior with the\n", - "# MNLE-conditional posterior.\n", + "# Plotting all three posteriors in one pairplot.\n", + "\n", "fig, ax = pairplot(\n", - " [\n", - " prior.sample((1000,)),\n", - " true_posterior_samples,\n", - " conditional_samples,\n", - " ],\n", - " points=theta_o,\n", + " [prior.sample((1000,))] + posterior_samples,\n", + " # points=theta_o,\n", " diag=\"kde\",\n", " upper=\"contour\",\n", " diag_kwargs=dict(bins=100),\n", @@ -770,13 +857,15 @@ " fig_kwargs=dict(\n", " points_offdiag=dict(marker=\"*\", markersize=10),\n", " points_colors=[\"k\"],\n", + "\n", " ),\n", " labels=[r\"$\\beta$\", r\"$\\rho$\"],\n", + " figsize=(10, 10),\n", ")\n", "\n", "plt.sca(ax[1, 1])\n", "plt.legend(\n", - " [\"Prior\", \"Reference\", \"MNLE\", r\"$\\theta_o$\"],\n", + " [\"prior\"] + [f\"Subject {idx+1}\" for idx in range(num_subjects)],\n", " frameon=False,\n", " fontsize=12,\n", ");" @@ -786,13 +875,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "They match accurately, showing that we can indeed post-hoc condition the trained 
class BinomialGammaPotential(BasePotential):
    """Binomial-Gamma potential for mixed data.

    Evaluates an unnormalized log-posterior (log-likelihood summed over iid
    trials plus prior log-probability) for observations whose first column
    holds a continuous value scored under an `InverseGamma` distribution
    (named reaction times, `rts`, in the code) and whose remaining columns
    hold values scored under a `Binomial` distribution (named `choices`).

    `theta` is split accordingly: its first column (`beta`) is used as the
    `rate` of the `InverseGamma`, the remaining columns (`rho`) as the `probs`
    of the `Binomial`.
    """

    def __init__(
        self,
        prior: Distribution,
        x_o: Tensor,
        concentration_scaling: Union[Tensor, float] = 1.0,
        device: str = "cpu",
    ):
        """Set up the potential.

        Args:
            prior: Distribution whose `log_prob` is added to the likelihood.
            x_o: Observed data; its leading dimension is treated as the trial
                dimension.
            concentration_scaling: Scaling of the `InverseGamma` concentration.
                Either a float applied to all trials, or a tensor whose first
                dimension matches the number of trials in `x_o`.
            device: Device passed on to `BasePotential`.
        """
        super().__init__(prior, x_o, device)

        # concentration_scaling needs to be a float or match the number of trials.
        if isinstance(concentration_scaling, Tensor):
            num_trials = x_o.shape[0]
            assert concentration_scaling.shape[0] == num_trials, (
                "concentration_scaling must have one entry per trial in x_o."
            )

            # Reshape to match convention (batch_size, num_trials, *event_shape)
            # so that it broadcasts against the theta batch in `iid_likelihood`.
            concentration_scaling = concentration_scaling.reshape(1, num_trials, -1)

        self.concentration_scaling = concentration_scaling

    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        """Return the potential for a batch of parameters.

        Args:
            theta: Parameters; passed through `atleast_2d` before evaluation.
            track_gradients: Whether gradients are tracked while evaluating
                the likelihood.

        Returns:
            Sum of the iid log-likelihood and the prior log-probability, one
            value per row of `theta`.
        """
        theta = atleast_2d(theta)

        with torch.set_grad_enabled(track_gradients):
            iid_ll = self.iid_likelihood(theta)

        return iid_ll + self.prior.log_prob(theta)

    def iid_likelihood(self, theta: Tensor) -> Tensor:
        """Return the log-likelihood of `x_o` summed over trials.

        Args:
            theta: Batch of parameters with shape `(batch_size, num_parameters)`.

        Returns:
            Tensor of shape `(batch_size,)`.
        """
        batch_size = theta.shape[0]
        num_trials = self.x_o.shape[0]
        # Insert a singleton trial axis so theta broadcasts against all trials.
        theta = theta.reshape(batch_size, 1, -1)
        beta, rho = theta[:, :, :1], theta[:, :, 1:]

        # Vectorized over the theta batch (dim 0) and the trials (dim 1).
        logprob_choices = Binomial(probs=rho).log_prob(
            self.x_o[:, 1:].reshape(1, num_trials, -1)
        )

        logprob_rts = InverseGamma(
            concentration=self.concentration_scaling * torch.ones_like(beta),
            rate=beta,
        ).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))

        # Drop only the trailing event axis. A bare `.squeeze()` would also
        # remove the batch axis for a single theta, or the trial axis for a
        # single trial, and break the shape check below.
        joint_likelihood = (logprob_choices + logprob_rts).squeeze(-1)

        assert joint_likelihood.shape == torch.Size([batch_size, num_trials])
        return joint_likelihood.sum(1)