NLE with multiple iid conditions #1331

Merged · 6 commits · Dec 21, 2024
136 changes: 134 additions & 2 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -1,7 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Optional, Tuple
import warnings
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
@@ -115,6 +116,52 @@
)
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore

def condition_on_theta(
self, theta_condition: Tensor, dims_to_sample: List[int]
) -> Callable:
"""Returns a potential function conditioned on a subset of theta dimensions.

The condition is a part of theta, but is assumed to correspond to a batch of iid
x_o. For example, it can be a batch of experimental conditions that corresponds
to a batch of i.i.d. trials in x_o.

Args:
theta_condition: The condition values to be conditioned on.
dims_to_sample: The indices of the columns in theta that will be sampled,
i.e., that are *not* conditioned. For example, if the original theta has
shape `(batch_dim, 3)` and `dims_to_sample=[0, 1]`, then the potential will
set `theta[:, 2] = theta_condition` at inference time.

Returns:
A potential function conditioned on the theta_condition.
"""

assert self.x_is_iid, "Conditioning is only supported for iid data."


def conditioned_potential(
theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
) -> Tensor:
assert (
len(dims_to_sample) == theta.shape[1]
), "dims_to_sample must match the number of parameters to sample."
theta_without_condition = theta[:, dims_to_sample]
x_o = x_o if x_o is not None else self.x_o
# x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
if x_o.dim() < 3:
x_o = reshape_to_sample_batch_event(
x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
)

return _log_likelihood_over_iid_conditions(
x=x_o,
theta_without_condition=theta_without_condition,
condition=theta_condition,
estimator=self.likelihood_estimator,
track_gradients=track_gradients,
)

return conditioned_potential
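
For context, a minimal usage sketch of the new method. The estimator, prior, and data names (`likelihood_estimator`, `prior`, `x_iid`, `conditions`) are placeholders; only `condition_on_theta`, its arguments, and the returned shape are taken from this diff:

```python
import torch

from sbi.inference.potentials.likelihood_based_potential import (
    LikelihoodBasedPotential,
)

# Placeholders: a likelihood estimator trained on theta = (mu, sigma, condition)
# and a prior over the first two dimensions only.
potential = LikelihoodBasedPotential(likelihood_estimator, prior, x_o=x_iid)

# One condition value per iid trial in x_o, e.g., shape (num_trials, 1).
conditioned_potential_fn = potential.condition_on_theta(
    theta_condition=conditions,
    dims_to_sample=[0, 1],  # sample mu and sigma; condition on the last column
)

# Evaluate a batch of candidate (mu, sigma) values; returns the iid-summed
# log likelihood of shape (x_batch_dim, theta_batch_dim).
theta = torch.randn(10, 2)
log_likelihoods = conditioned_potential_fn(theta)
```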


def _log_likelihoods_over_trials(
x: Tensor,
@@ -172,6 +219,78 @@
return log_likelihood_trial_sum


def _log_likelihood_over_iid_conditions(
x: Tensor,
theta_without_condition: Tensor,
condition: Tensor,
estimator: ConditionalDensityEstimator,
track_gradients: bool = False,
) -> Tensor:
"""Returns $\\log(p(x_o|\theta, condition)$, where x_o is a batch of iid data, and
janfb marked this conversation as resolved.
Show resolved Hide resolved
condition is a matching batch of conditions.

This function is different from `_log_likelihoods_over_trials` in that it moves the
iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
the likelihood estimator is conditioned on a batch of conditions that are iid with
the batch of `x`. It avoids evaluating the likelihood for every combination of `x`
and `condition`. Instead, it manually constructs a batch covering all combinations
of iid trials and theta batch entries and reshapes to sum over the iid
likelihoods.

Args:
x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where sample_dim
holds the i.i.d. trials and batch_dim holds a batch of xs, e.g., non-iid
observations.
theta_without_condition: Batch of parameters `(theta_batch_dim,
num_parameters)`.
condition: Batch of conditions of shape `(sample_dim, num_conditions)`, must
match x's `sample_dim`.
estimator: ConditionalDensityEstimator for the likelihood.
track_gradients: Whether to track gradients.

Returns:
log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
theta_batch_dim, summed over all i.i.d. trials. Shape
`(x_batch_dim, theta_batch_dim)`.
"""
assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
assert (
condition.dim() == 2
), "condition must have shape (sample_dim, num_conditions)."
assert (
theta_without_condition.dim() == 2
), "theta must have shape (batch_dim, num_parameters)."
num_trials, num_xs = x.shape[:2]
num_thetas = theta_without_condition.shape[0]
assert (
condition.shape[0] == num_trials
), "Condition batch size must match the number of iid trials in x."

# move the iid batch dimension onto the batch dimension of theta and repeat it there
x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)

# construct theta and condition to cover all trial-theta combinations
theta_with_condition = torch.cat(
[
theta_without_condition.repeat(num_trials, 1), # repeat ABAB
condition.repeat_interleave(num_thetas, dim=0), # repeat AABB
],
dim=-1,
)

with torch.set_grad_enabled(track_gradients):
# Calculate the log likelihood in one batch. Returns (num_xs, num_trials * num_thetas)
log_likelihood_trial_batch = estimator.log_prob(
x_repeated, condition=theta_with_condition
)
# Reshape to (num_xs, num_trials, num_thetas) and sum over the trial log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
num_xs, num_trials, num_thetas
).sum(1)

return log_likelihood_trial_sum
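
The ABAB / AABB repeats above pair every theta with every iid condition exactly once; a small standalone check in plain PyTorch:

```python
import torch

num_trials, num_thetas = 2, 3
theta = torch.tensor([[0.0], [1.0], [2.0]])  # 3 thetas, 1 parameter
condition = torch.tensor([[10.0], [20.0]])   # 2 iid trials, 1 condition

pairs = torch.cat(
    [
        theta.repeat(num_trials, 1),                     # ABAB: 0,1,2,0,1,2
        condition.repeat_interleave(num_thetas, dim=0),  # AABB: 10,10,10,20,20,20
    ],
    dim=-1,
)
print(pairs)
# tensor([[ 0., 10.],
#         [ 1., 10.],
#         [ 2., 10.],
#         [ 0., 20.],
#         [ 1., 20.],
#         [ 2., 20.]])
# Every (theta, condition) combination appears exactly once, so a single
# batched log_prob call covers all trial-theta pairs, and reshaping to
# (num_trials, num_thetas) and summing over dim 0 gives the per-theta
# iid-summed log likelihood.
```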


def mixed_likelihood_estimator_based_potential(
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
@@ -192,6 +311,13 @@
to unconstrained space.
"""

warnings.warn(
"This function is deprecated and will be removed in a future release. Use "
"`likelihood_estimator_based_potential` instead.",
DeprecationWarning,
stacklevel=2,
)

device = str(next(likelihood_estimator.discrete_net.parameters()).device)

potential_fn = MixedLikelihoodBasedPotential(
@@ -212,6 +338,13 @@
):
super().__init__(likelihood_estimator, prior, x_o, device)

warnings.warn(
"This class is deprecated and will be removed in a future release. Use "
"`LikelihoodBasedPotential` instead.",
DeprecationWarning,
stacklevel=2,
)

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore

@@ -231,7 +364,6 @@
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
# TODO log_prob_iid
log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
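Both deprecation warnings steer users to the generic potentials. A migration sketch: the replacement call matches the `mnle.py` change below, the commented-out call is the deprecated one, and `likelihood_estimator` and `prior` are placeholders:

```python
from sbi.inference.potentials import likelihood_estimator_based_potential

# Deprecated (now emits a DeprecationWarning):
# potential_fn, theta_transform = mixed_likelihood_estimator_based_potential(
#     likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
# )

# Replacement:
potential_fn, theta_transform = likelihood_estimator_based_potential(
    likelihood_estimator, prior, x_o=None
)
```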
6 changes: 2 additions & 4 deletions sbi/inference/trainers/nle/mnle.py
@@ -7,7 +7,7 @@
from torch.distributions import Distribution

from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
@@ -155,9 +155,7 @@ def build_posterior(
(
potential_fn,
theta_transform,
) = mixed_likelihood_estimator_based_potential(
likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
)
) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)

if sample_with == "mcmc":
self._posterior = MCMCPosterior(
2 changes: 1 addition & 1 deletion sbi/utils/conditional_density_utils.py
@@ -293,7 +293,7 @@ def __init__(
masked outside of prior.
"""
condition = torch.atleast_2d(condition)
if condition.shape[0] != 1:
if condition.shape[0] > 1:
raise ValueError("Condition with batch size > 1 not supported.")

self.potential_fn = potential_fn
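The relaxed check above rejects only genuinely batched conditions, matching the wording of the error message. A standalone sketch of the guard (the function name is illustrative, not from the library):

```python
import torch

def check_condition(condition: torch.Tensor) -> torch.Tensor:
    condition = torch.atleast_2d(condition)
    # Only a true batch (more than one condition row) is rejected.
    if condition.shape[0] > 1:
        raise ValueError("Condition with batch size > 1 not supported.")
    return condition

check_condition(torch.tensor([1.0, 2.0]))  # becomes shape (1, 2): accepted
# check_condition(torch.ones(3, 2))        # would raise ValueError
```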
4 changes: 2 additions & 2 deletions sbi/utils/sbiutils.py
@@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -

if num_unique_z < num_unique * (1 - duplicate_tolerance):
warnings.warn(
"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
"datapoints. Before z-scoring, it had been {num_unique}. This can "
f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
f"datapoints. Before z-scoring, it had been {num_unique}. This can "
"occur due to numerical inaccuracies when the data covers a large "
"range of values. Consider either setting `z_score_x=False` (but "
"beware that this can be problematic for training the NN) or exclude "
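The fix above adds the missing `f` prefix; without it, Python emits the braces literally instead of interpolating the variables. A minimal illustration:

```python
num_unique_z, num_unique = 90, 100

broken = "resulted in {num_unique_z} unique datapoints (was {num_unique})"
fixed = f"resulted in {num_unique_z} unique datapoints (was {num_unique})"

print(broken)  # resulted in {num_unique_z} unique datapoints (was {num_unique})
print(fixed)   # resulted in 90 unique datapoints (was 100)
```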