Skip to content

Commit

Permalink
Allow 1D prior with batch dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Sep 26, 2024
1 parent 4a7ed2b commit 1d4ee7a
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 2 deletions.
6 changes: 6 additions & 0 deletions sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sbi.utils.user_input_checks_utils import (
CustomPriorWrapper,
MultipleIndependent,
OneDimPriorWrapper,
PytorchReturnTypeWrapper,
)

Expand Down Expand Up @@ -220,6 +221,11 @@ def process_pytorch_prior(prior: Distribution) -> Tuple[Distribution, int, bool]
# This will fail for float64 priors.
check_prior_return_type(prior)

# Potentially required wrapper if the prior returns an additional sample dimension
# for `.log_prob()`.
if prior.log_prob(prior.sample(torch.Size((10,)))).shape == torch.Size([10, 1]):
prior = OneDimPriorWrapper(prior, validate_args=False)

theta_numel = prior.sample().numel()

return prior, theta_numel, False
Expand Down
66 changes: 66 additions & 0 deletions sbi/utils/user_input_checks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,69 @@ def build_support(
support = constraints.interval(lower_bound, upper_bound)

return support


class OneDimPriorWrapper(Distribution):
"""Wrap batched 1D distributions to get rid of the batch dim of `.log_prob()`.
1D pytorch distributions such as `torch.distributions.Exponential`, `.Uniform`, or
`.Normal` do not, by default return __any__ sample or batch dimension. E.g.:
```python
dist = torch.distributions.Exponential(torch.tensor(3.0))
dist.sample((10,)).shape # (10,)
```
`sbi` will raise an error that the sample dimension is missing. A simple solution is
to add a batch dimension to `dist` as follows:
```python
dist = torch.distributions.Exponential(torch.tensor([3.0]))
dist.sample((10,)).shape # (10, 1)
```
Unfortunately, this `dist` will return the batch dimension also for `.log_prob():
```python
dist = torch.distributions.Exponential(torch.tensor([3.0]))
samples = dist.sample((10,))
dist.log_prob(samples).shape # (10, 1)
```
This will lead to unexpected errors in `sbi`. The point of this class is to wrap
those batched 1D distributions to get rid of their batch dimension in `.log_prob()`.
"""

def __init__(
self,
prior: Distribution,
validate_args=None,
) -> None:
super().__init__(
batch_shape=prior.batch_shape,
event_shape=prior.event_shape,
validate_args=(
prior._validate_args if validate_args is None else validate_args
),
)
self.prior = prior

def sample(self, *args, **kwargs) -> Tensor:
return self.prior.sample(*args, **kwargs)

def log_prob(self, *args, **kwargs) -> Tensor:
"""Override the log_prob method to get rid of the additional batch dimension."""
return self.prior.log_prob(*args, **kwargs)[..., 0]

@property
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
return self.prior.arg_constraints

@property
def support(self):
return self.prior.support

@property
def mean(self) -> Tensor:
return self.prior.mean

@property
def variance(self) -> Tensor:
return self.prior.variance
20 changes: 18 additions & 2 deletions tests/user_input_checks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
import torch
from pyknos.mdn.mdn import MultivariateGaussianMDN
from torch import Tensor, eye, nn, ones, zeros
from torch.distributions import Beta, Distribution, Gamma, MultivariateNormal, Uniform
from torch.distributions import (
Beta,
Distribution,
Exponential,
Gamma,
MultivariateNormal,
Uniform,
)

from sbi.inference import NPE_A, NPE_C, simulate_for_sbi
from sbi.inference.posteriors.direct_posterior import DirectPosterior
Expand All @@ -27,6 +34,7 @@
from sbi.utils.user_input_checks_utils import (
CustomPriorWrapper,
MultipleIndependent,
OneDimPriorWrapper,
PytorchReturnTypeWrapper,
)

Expand Down Expand Up @@ -93,6 +101,11 @@ def matrix_simulator(theta):
BoxUniform(zeros(3, dtype=torch.float64), ones(3, dtype=torch.float64)),
dict(),
),
(
OneDimPriorWrapper,
Exponential(torch.tensor([3.0])),
dict(),
),
),
)
def test_prior_wrappers(wrapper, prior, kwargs):
Expand All @@ -118,6 +131,9 @@ def test_prior_wrappers(wrapper, prior, kwargs):
# Test transform
mcmc_transform(prior)

# For 1D priors, the `log_prob()` should not have a batch dim.
assert len(prior.log_prob(prior.sample((10,))).shape) == 1


def test_reinterpreted_batch_dim_prior():
"""Test whether the right warning and error are raised for reinterpreted priors."""
Expand Down Expand Up @@ -268,7 +284,6 @@ def test_prepare_sbi_problem(simulator: Callable, prior):
prior: prior as defined by the user (pytorch, scipy, custom)
x_shape: shape of data as defined by the user.
"""

prior, _, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator, prior, prior_returns_numpy)
check_sbi_inputs(simulator, prior)
Expand Down Expand Up @@ -308,6 +323,7 @@ def test_prepare_sbi_problem(simulator: Callable, prior):
MultivariateNormal(zeros(2), eye(2)),
),
),
(diagonal_linear_gaussian, Exponential(torch.tensor([3.0]))),
),
)
def test_inference_with_user_sbi_problems(
Expand Down

0 comments on commit 1d4ee7a

Please sign in to comment.