Raise for sample_conditional in snpe-a
michaeldeistler committed Aug 3, 2021
1 parent c3a3f46 commit aad5c5e
Showing 3 changed files with 39 additions and 37 deletions.
26 changes: 18 additions & 8 deletions sbi/inference/posteriors/direct_posterior.py
@@ -21,6 +21,7 @@
    atleast_2d,
    batched_first_of_batch,
    ensure_theta_batched,
+    atleast_2d_float32_tensor,
)
from sbi.utils.conditional_density import extract_and_transform_mog, condition_mog

@@ -475,15 +476,24 @@ def sample_conditional(
        Returns:
            Samples from conditional posterior.
        """
-        if type(self.net._distribution) is mdn:
+        if not hasattr(self.net, "_distribution"):
+            raise NotImplementedError(
+                "`sample_conditional` is not implemented for SNPE-A."
+            )
+
+        net = self.net
+        x = atleast_2d_float32_tensor(self._x_else_default_x(x))
+
+        if type(net._distribution) is mdn:
            condition = atleast_2d_float32_tensor(condition)
            num_samples = torch.Size(sample_shape).numel()

-            logits, means, precfs, _ = extract_and_transform_mog(self, x)
+            logits, means, precfs, _ = extract_and_transform_mog(nn=net, context=x)
            logits, means, precfs, _ = condition_mog(
                self._prior, condition, dims_to_sample, logits, means, precfs
            )

-            # Currently difficult to integrate `sample_posterior_within_prior`
+            # Currently difficult to integrate `sample_posterior_within_prior`.
            warn(
                "Sampling MoG analytically. "
                "Some of the samples might not be within the prior support!"
@@ -515,24 +525,24 @@ def log_prob_conditional(
"""Evaluates the conditional posterior probability of a MDN at a context x for
a value theta given a condition.
This function only works for MDN based posteriors, becuase evaluation is done
This function only works for MDN based posteriors, becuase evaluation is done
analytically. For all other density estimators a `NotImplementedError` will be
raised!
raised!
Args:
theta: Parameters $\theta$.
condition: Parameter set that all dimensions not specified in
`dims_to_sample` will be fixed to. Should contain dim_theta elements,
i.e. it could e.g. be a sample from the posterior distribution.
The entries at all `dims_to_sample` will be ignored.
dims_to_evaluate: Which dimensions to evaluate the sample for.
dims_to_evaluate: Which dimensions to evaluate the sample for.
The dimensions not specified in `dims_to_evaluate` will be fixed to values given in `condition`.
x: Conditioning context for posterior $p(\theta|x)$. If not provided,
fall back onto `x` passed to `set_default_x()`.
Returns:
log_prob: `(len(θ),)`-shaped normalized (!) log posterior probability
$\log p(\theta|x) for θ in the support of the prior, -∞ (corresponding
log_prob: `(len(θ),)`-shaped normalized (!) log posterior probability
$\log p(\theta|x) for θ in the support of the prior, -∞ (corresponding
to 0 probability) outside.
"""

27 changes: 11 additions & 16 deletions sbi/utils/conditional_density.py
@@ -8,6 +8,7 @@
from torch import Tensor

from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn
+from pyknos.nflows.flows import Flow
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.torchutils import BoxUniform

@@ -335,7 +336,7 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor:


def extract_and_transform_mog(
-    posterior: "DirectPosterior", context: Tensor = None
+    nn: Flow, context: Tensor = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Extracts the Mixture of Gaussians (MoG) parameters
from an MDN based DirectPosterior at either the default x or input x.
@@ -344,26 +345,21 @@
        posterior: DirectPosterior instance.
        context: Conditioning context for posterior $p(\theta|x)$. If not provided,
            fall back onto `x` passed to `set_default_x()`.

    Returns:
        norm_logits: Normalised log weights of the underlying MoG.
            (batch_size, n_mixtures)
        means_transformed: Recentred and rescaled means of the underlying MoG.
            (batch_size, n_mixtures, n_dims)
        precfs_transformed: Rescaled precision factors of the underlying MoG.
            (batch_size, n_mixtures, n_dims, n_dims)
        sumlogdiag: Sum of the log of the diagonal of the precision factors
            of the new conditional distribution. (batch_size, n_mixtures)
    """

    # extract and rescale means, mixture components and covariances
-    nn = posterior.net
    dist = nn._distribution

-    if context == None:
-        encoded_x = nn._embedding_net(posterior.default_x)
-    else:
-        encoded_x = nn._embedding_net(context)
+    encoded_x = nn._embedding_net(context)

    logits, means, _, sumlogdiag, precfs = dist.get_mixture_components(encoded_x)
    norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
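Aside (not part of the diff): the `logsumexp` line above is the standard log-space normalisation of the mixture weights; a self-contained sketch with made-up shapes:

import torch

logits = torch.randn(2, 5)  # (batch_size, n_mixtures), raw MoG logits
norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
print(norm_logits.exp().sum(dim=-1))  # tensor([1., 1.]): weights sum to one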
@@ -384,7 +380,7 @@


def condition_mog(
-    prior: "Prior",
+    prior: Any,
    condition: Tensor,
    dims: List[int],
    logits: Tensor,
@@ -404,7 +400,7 @@
            `condition`.
        logits: Log weights of the MoG. (batch_size, n_mixtures)
        means: Means of the MoG. (batch_size, n_mixtures, n_dims)
        precfs: Precision factors of the MoG.
            (batch_size, n_mixtures, n_dims, n_dims)

    Raises:
@@ -413,21 +409,20 @@
    Returns:
        logits: Log weights of the conditioned MoG. (batch_size, n_mixtures)
        means: Means of the conditioned MoG. (batch_size, n_mixtures, n_dims)
        precfs_xx: Precision factors of the conditioned MoG.
            (batch_size, n_mixtures, n_dims, n_dims)
        sumlogdiag: Sum of the log of the diagonal of the precision factors
            of the new conditional distribution. (batch_size, n_mixtures)
    """

-    support = prior.support

    n_mixtures, n_dims = means.shape[1:]

    mask = torch.zeros(n_dims, dtype=bool)
    mask[dims] = True

-    # check whether the condition is within the prior bounds
+    # Check whether the condition is within the prior bounds.
    if type(prior) is torch.distributions.uniform.Uniform or type(prior) is BoxUniform:
+        support = prior.support.base_constraint
        cond_ubound = support.upper_bound[~mask]
        cond_lbound = support.lower_bound[~mask]
        within_support = torch.logical_and(
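Aside (not part of the diff): a self-contained sketch of the two steps this hunk touches, with made-up numbers and full precision matrices instead of the precision factors the real code uses. First the condition is checked against the box bounds via `support.base_constraint`; then, per mixture component, the Gaussian conditioning rule gives p(y | x) = N(mu_y - P_yy^{-1} P_yx (x - mu_x), P_yy^{-1}) for a joint precision P = [[P_yy, P_yx], [P_xy, P_xx]].

import torch
from sbi.utils.torchutils import BoxUniform

prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))
dims = [0, 1]                       # dims to sample; dim 2 is conditioned
mask = torch.zeros(3, dtype=torch.bool)
mask[dims] = True

support = prior.support.base_constraint
condition = torch.tensor([1.5])     # made-up value for the conditioned dim
assert (support.lower_bound[~mask] <= condition).all()
assert (condition <= support.upper_bound[~mask]).all()

# Condition one made-up Gaussian component via its precision matrix.
mu = torch.tensor([0.5, -1.0, 2.0])
A = torch.randn(3, 3)
P = torch.linalg.inv(A @ A.T + 3 * torch.eye(3))  # random SPD precision
P_yy, P_yx = P[:2, :2], P[:2, 2:]
mean_cond = mu[:2] - torch.linalg.solve(P_yy, P_yx @ (condition - mu[2:]))
prec_cond = P_yy                    # the conditional precision is just P_yy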
23 changes: 10 additions & 13 deletions tests/linearGaussian_snpe_test.py
@@ -73,11 +73,11 @@ def test_c2st_snpe_on_linearGaussian(
    target_samples = samples_true_posterior_linear_gaussian_uniform_prior(
        x_o, likelihood_shift, likelihood_cov, prior=prior, num_samples=num_samples
    )

    simulator, prior = prepare_for_sbi(
        lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior
    )

    inference = snpe_method(prior, show_progress_bars=False)

    theta, x = simulate_for_sbi(
@@ -351,8 +351,7 @@ def test_api_snpe_c_posterior_correction(sample_with, mcmc_method, prior_str, set_seed):


@pytest.mark.slow
-@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
-def test_sample_conditional(snpe_method: type, set_seed):
+def test_sample_conditional(set_seed):
"""
Test whether sampling from the conditional gives the same results as evaluating.
@@ -380,14 +379,11 @@ def simulator(theta):
        else:
            return linear_gaussian(theta, -likelihood_shift, likelihood_cov)

-    if snpe_method == SNPE_A:
-        net = utils.posterior_nn("mdn_snpe_a", hidden_features=20)
-    else:
-        net = utils.posterior_nn("maf", hidden_features=20)
+    net = utils.posterior_nn("maf", hidden_features=20)

    simulator, prior = prepare_for_sbi(simulator, prior)

-    inference = snpe_method(prior, density_estimator=net, show_progress_bars=False)
+    inference = SNPE_C(prior, density_estimator=net, show_progress_bars=False)

    # We need a pretty big dataset to properly model the bimodality.
    theta, x = simulate_for_sbi(simulator, prior, 10000)
@@ -438,10 +434,10 @@ def simulator(theta):
    max_err = np.max(error)
    assert max_err < 0.0026


@pytest.mark.slow
-@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
-def test_mdn_conditional_density(snpe_method: type, num_dim: int = 3, cond_dim: int = 1):
+def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1):
    """Test whether the conditional density inferred from MDN parameters of a
    `DirectPosterior` matches analytical results for MVN. This uses an n-D joint and
    conditions on the last m values to generate a conditional.
@@ -490,7 +486,7 @@ def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    simulator, prior = prepare_for_sbi(simulator, prior)
-    inference = snpe_method(prior, show_progress_bars=False, density_estimator="mdn")
+    inference = SNPE_C(prior, show_progress_bars=False, density_estimator="mdn")

    theta, x = simulate_for_sbi(
        simulator, prior, num_simulations, simulation_batch_size=1000
@@ -507,6 +503,7 @@ def simulator(theta):
        alg="analytic_mdn_conditioning_of_direct_posterior",
    )
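Aside (not part of the diff): the analytical MVN reference this test compares against follows the usual Gaussian conditioning identity in covariance form, mu_cond = mu_y + S_yx S_xx^{-1} (x_cond - mu_x) and S_cond = S_yy - S_yx S_xx^{-1} S_xy. A self-contained sketch with a made-up 3-D joint, conditioning on the last dimension:

import torch

mu = torch.zeros(3)
S = torch.tensor([[1.0, 0.3, 0.2],
                  [0.3, 1.0, 0.1],
                  [0.2, 0.1, 1.0]])   # made-up joint covariance
x_cond = torch.tensor([0.5])          # condition on the last dimension

S_yy, S_yx, S_xx = S[:2, :2], S[:2, 2:], S[2:, 2:]
mu_cond = mu[:2] + S_yx @ torch.linalg.solve(S_xx, x_cond - mu[2:])
S_cond = S_yy - S_yx @ torch.linalg.solve(S_xx, S_yx.T)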


@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
def test_example_posterior(snpe_method: type):
"""Return an inferred `NeuralPosterior` for interactive examination."""
