diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 22c056a13..f407d632f 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -17,6 +17,7 @@ from sbi.utils.user_input_checks_utils import ( CustomPriorWrapper, MultipleIndependent, + OneDimPriorWrapper, PytorchReturnTypeWrapper, ) @@ -220,6 +221,11 @@ def process_pytorch_prior(prior: Distribution) -> Tuple[Distribution, int, bool] # This will fail for float64 priors. check_prior_return_type(prior) + # Potentially required wrapper if the prior returns an additional sample dimension + # for `.log_prob()`. + if prior.log_prob(prior.sample(torch.Size((10,)))).shape == torch.Size([10, 1]): + prior = OneDimPriorWrapper(prior, validate_args=False) + theta_numel = prior.sample().numel() return prior, theta_numel, False diff --git a/sbi/utils/user_input_checks_utils.py b/sbi/utils/user_input_checks_utils.py index 184d167c5..554a6f8f2 100644 --- a/sbi/utils/user_input_checks_utils.py +++ b/sbi/utils/user_input_checks_utils.py @@ -373,3 +373,69 @@ def build_support( support = constraints.interval(lower_bound, upper_bound) return support + + +class OneDimPriorWrapper(Distribution): + """Wrap batched 1D distributions to get rid of the batch dim of `.log_prob()`. + + 1D pytorch distributions such as `torch.distributions.Exponential`, `.Uniform`, or + `.Normal` do not, by default return __any__ sample or batch dimension. E.g.: + ```python + dist = torch.distributions.Exponential(torch.tensor(3.0)) + dist.sample((10,)).shape # (10,) + ``` + + `sbi` will raise an error that the sample dimension is missing. A simple solution is + to add a batch dimension to `dist` as follows: + ```python + dist = torch.distributions.Exponential(torch.tensor([3.0])) + dist.sample((10,)).shape # (10, 1) + ``` + + Unfortunately, this `dist` will return the batch dimension also for `.log_prob(): + ```python + dist = torch.distributions.Exponential(torch.tensor([3.0])) + samples = dist.sample((10,)) + dist.log_prob(samples).shape # (10, 1) + ``` + + This will lead to unexpected errors in `sbi`. The point of this class is to wrap + those batched 1D distributions to get rid of their batch dimension in `.log_prob()`. + """ + + def __init__( + self, + prior: Distribution, + validate_args=None, + ) -> None: + super().__init__( + batch_shape=prior.batch_shape, + event_shape=prior.event_shape, + validate_args=( + prior._validate_args if validate_args is None else validate_args + ), + ) + self.prior = prior + + def sample(self, *args, **kwargs) -> Tensor: + return self.prior.sample(*args, **kwargs) + + def log_prob(self, *args, **kwargs) -> Tensor: + """Override the log_prob method to get rid of the additional batch dimension.""" + return self.prior.log_prob(*args, **kwargs)[..., 0] + + @property + def arg_constraints(self) -> Dict[str, constraints.Constraint]: + return self.prior.arg_constraints + + @property + def support(self): + return self.prior.support + + @property + def mean(self) -> Tensor: + return self.prior.mean + + @property + def variance(self) -> Tensor: + return self.prior.variance diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py index 2766d8513..cd275a7b0 100644 --- a/tests/user_input_checks_test.py +++ b/tests/user_input_checks_test.py @@ -10,7 +10,14 @@ import torch from pyknos.mdn.mdn import MultivariateGaussianMDN from torch import Tensor, eye, nn, ones, zeros -from torch.distributions import Beta, Distribution, Gamma, MultivariateNormal, Uniform +from torch.distributions import ( + Beta, + Distribution, + Exponential, + Gamma, + MultivariateNormal, + Uniform, +) from sbi.inference import NPE_A, NPE_C, simulate_for_sbi from sbi.inference.posteriors.direct_posterior import DirectPosterior @@ -27,6 +34,7 @@ from sbi.utils.user_input_checks_utils import ( CustomPriorWrapper, MultipleIndependent, + OneDimPriorWrapper, PytorchReturnTypeWrapper, ) @@ -93,6 +101,11 @@ def matrix_simulator(theta): BoxUniform(zeros(3, dtype=torch.float64), ones(3, dtype=torch.float64)), dict(), ), + ( + OneDimPriorWrapper, + Exponential(torch.tensor([3.0])), + dict(), + ), ), ) def test_prior_wrappers(wrapper, prior, kwargs): @@ -118,6 +131,9 @@ def test_prior_wrappers(wrapper, prior, kwargs): # Test transform mcmc_transform(prior) + # For 1D priors, the `log_prob()` should not have a batch dim. + assert len(prior.log_prob(prior.sample((10,))).shape) == 1 + def test_reinterpreted_batch_dim_prior(): """Test whether the right warning and error are raised for reinterpreted priors.""" @@ -268,7 +284,6 @@ def test_prepare_sbi_problem(simulator: Callable, prior): prior: prior as defined by the user (pytorch, scipy, custom) x_shape: shape of data as defined by the user. """ - prior, _, prior_returns_numpy = process_prior(prior) simulator = process_simulator(simulator, prior, prior_returns_numpy) check_sbi_inputs(simulator, prior) @@ -308,6 +323,7 @@ def test_prepare_sbi_problem(simulator: Callable, prior): MultivariateNormal(zeros(2), eye(2)), ), ), + (diagonal_linear_gaussian, Exponential(torch.tensor([3.0]))), ), ) def test_inference_with_user_sbi_problems(