Skip to content

Commit

Permalink
Track gradients only for the HMC and NUTS Pyro methods (not for slice sampling)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler authored and janfb committed May 18, 2021
1 parent 1ccf785 commit 9b23625
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 18 deletions.
23 changes: 15 additions & 8 deletions sbi/inference/posteriors/direct_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -525,8 +526,10 @@ def __call__(
self.device = next(posterior_nn.parameters()).device
self.x = atleast_2d(x).to(self.device)

if mcmc_method in ("slice", "hmc", "nuts"):
return self.pyro_potential
if mcmc_method == "slice":
return partial(self.pyro_potential, track_gradients=False)
elif mcmc_method in ("hmc", "nuts"):
return partial(self.pyro_potential, track_gradients=True)
else:
return self.np_potential

Expand Down Expand Up @@ -556,7 +559,9 @@ def np_potential(self, theta: np.ndarray) -> ScalarFloat:

return target_log_prob

def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
) -> Tensor:
r"""Return posterior log prob. of theta $p(\theta|x)$, -inf where outside prior.
Args:
Expand All @@ -568,11 +573,13 @@ def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:

theta = next(iter(theta.values()))

# Notice opposite sign to numpy.
# Move theta to device for evaluation.
log_prob_posterior = -self.posterior_nn.log_prob(
inputs=theta.to(self.device), context=self.x
).cpu()
with torch.set_grad_enabled(track_gradients):
# Notice opposite sign to numpy.
# Move theta to device for evaluation.
log_prob_posterior = -self.posterior_nn.log_prob(
inputs=theta.to(self.device),
context=self.x,
).cpu()

in_prior_support = within_support(self.prior, theta)

Expand Down
20 changes: 14 additions & 6 deletions sbi/inference/posteriors/likelihood_based_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from warnings import warn

Expand Down Expand Up @@ -376,19 +377,21 @@ def __call__(
self.device = next(likelihood_nn.parameters()).device
self.x = atleast_2d(x).to(self.device)

if mcmc_method in ("slice", "hmc", "nuts"):
return self.pyro_potential
if mcmc_method == "slice":
return partial(self.pyro_potential, track_gradients=False)
elif mcmc_method in ("hmc", "nuts"):
return partial(self.pyro_potential, track_gradients=True)
else:
return self.np_potential

def log_likelihood(self, theta: Tensor) -> Tensor:
def log_likelihood(self, theta: Tensor, track_gradients: bool = False) -> Tensor:
"""Return log likelihood of fixed data given a batch of parameters."""

log_likelihoods = LikelihoodBasedPosterior._log_likelihoods_over_trials(
x=self.x,
theta=ensure_theta_batched(theta).to(self.device),
net=self.likelihood_nn,
track_gradients=False,
track_gradients=track_gradients,
)

return log_likelihoods
Expand All @@ -407,7 +410,9 @@ def np_potential(self, theta: np.array) -> ScalarFloat:
# Notice opposite sign to pyro potential.
return self.log_likelihood(theta).cpu() + self.prior.log_prob(theta)

def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
) -> Tensor:
r"""Return posterior log probability of parameters $p(\theta|x)$.
Args:
Expand All @@ -421,4 +426,7 @@ def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:

theta = next(iter(theta.values()))

return -(self.log_likelihood(theta).cpu() + self.prior.log_prob(theta))
return -(
self.log_likelihood(theta, track_gradients=track_gradients).cpu()
+ self.prior.log_prob(theta)
)
16 changes: 12 additions & 4 deletions sbi/inference/posteriors/ratio_based_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from warnings import warn

Expand Down Expand Up @@ -412,8 +413,10 @@ def __call__(
self.device = next(classifier.parameters()).device
self.x = atleast_2d(x).to(self.device)

if mcmc_method in ("slice", "hmc", "nuts"):
return self.pyro_potential
if mcmc_method == "slice":
return partial(self.pyro_potential, track_gradients=False)
elif mcmc_method in ("hmc", "nuts"):
return partial(self.pyro_potential, track_gradients=True)
else:
return self.np_potential

Expand All @@ -438,7 +441,9 @@ def np_potential(self, theta: np.array) -> ScalarFloat:
# Notice opposite sign to pyro potential.
return log_ratio.cpu() + self.prior.log_prob(theta)

def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
) -> Tensor:
r"""Return potential for Pyro sampler.
Note: for Pyro this is the negative unnormalized posterior log prob.
Expand All @@ -458,7 +463,10 @@ def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
theta = ensure_theta_batched(theta)

log_ratio = RatioBasedPosterior._log_ratios_over_trials(
self.x, theta.to(self.device), self.classifier, track_gradients=False
self.x,
theta.to(self.device),
self.classifier,
track_gradients=track_gradients,
)

return -(log_ratio.cpu() + self.prior.log_prob(theta))

0 comments on commit 9b23625

Please sign in to comment.