-
Notifications
You must be signed in to change notification settings - Fork 154
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a4f060d
commit 6fa9b96
Showing
39 changed files
with
382 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from datetime import datetime | ||
import pytest | ||
|
||
import shutil | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption( | ||
"--print-harvest", | ||
action="store_true", | ||
default=False, | ||
help="Print the harvest results at the end of the test session", | ||
) | ||
|
||
|
||
def pytest_terminal_summary(terminalreporter, exitstatus, config): | ||
# Only print the harvest results if the --print-harvest flag is used | ||
if config.getoption("--print-harvest"): | ||
# Dynamically center the summary title in the terminal width | ||
terminal_width = shutil.get_terminal_size().columns | ||
summary_text = " short test summary info " | ||
centered_line = summary_text.center(terminal_width, '=') | ||
terminalreporter.write_line(centered_line) | ||
|
||
|
||
@pytest.mark.parametrize('p', ['world', 'self'], ids=str) | ||
def test_foo(p, results_bag): | ||
""" | ||
A dummy test, parametrized so that it is executed twice | ||
""" | ||
|
||
# Let's store some things in the results bag | ||
results_bag.nb_letters = len(p) | ||
results_bag.current_time = datetime.now().isoformat() | ||
|
||
|
||
def test_synthesis(fixture_store): | ||
""" | ||
In this test we inspect the contents of the fixture store so far, and | ||
check that the 'results_bag' entry contains a dict <test_id>: <results_bag> | ||
""" | ||
# print the keys in the store | ||
results = fixture_store["results_bag"] | ||
|
||
# print what is available for the 'results_bag' entry | ||
print("\n--- Harvested Test Results ---") | ||
for k, v in results.items(): | ||
print(k) | ||
for kk, vv in v.items(): | ||
print(kk, vv) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,17 @@ | ||
from .gaussian_linear import GaussianLinear | ||
from .linear_mvg import LinearMVG2d | ||
from .slcp import Slcp | ||
from .two_moons import TwoMoons | ||
|
||
|
||
def get_task(name: str): | ||
if name == "two_moons": | ||
return TwoMoons() | ||
elif name == "linear_mvg_2d": | ||
return LinearMVG2d() | ||
elif name == "gaussian_linear": | ||
return GaussianLinear() | ||
elif name == "slcp": | ||
return Slcp() | ||
else: | ||
raise ValueError(f"Unknown task {name}") |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from functools import partial | ||
from typing import Callable | ||
|
||
import torch | ||
from torch.distributions import Distribution, MultivariateNormal | ||
|
||
from sbi.simulators.linear_gaussian import ( | ||
diagonal_linear_gaussian, | ||
true_posterior_linear_gaussian_mvn_prior, | ||
) | ||
|
||
from .base_task import Task | ||
|
||
|
||
class GaussianLinear(Task): | ||
def __init__(self): | ||
self.simulator_scale = 0.1 | ||
self.dim = 5 | ||
super().__init__("gaussian_linear") | ||
|
||
def theta_dim(self) -> int: | ||
return self.dim | ||
|
||
def x_dim(self) -> int: | ||
return self.dim | ||
|
||
def get_reference_posterior_samples(self, idx: int) -> torch.Tensor: | ||
x_o = self.get_observation(idx) | ||
posterior = true_posterior_linear_gaussian_mvn_prior( | ||
x_o, | ||
torch.zeros(self.dim), | ||
self.simulator_scale * torch.eye(self.dim), | ||
torch.zeros(self.dim), | ||
torch.eye(self.dim), | ||
) | ||
|
||
return posterior.sample((10_000,)) | ||
|
||
def get_true_parameters(self, idx: int) -> torch.Tensor: | ||
torch.manual_seed(idx) | ||
return self.get_prior().sample() | ||
|
||
def get_observation(self, idx: int) -> torch.Tensor: | ||
theta_o = self.get_true_parameters(idx) | ||
x_o = self.get_simulator()(theta_o[None, :])[0] | ||
return x_o | ||
|
||
def get_prior(self) -> Distribution: | ||
return MultivariateNormal(torch.zeros(self.dim), torch.eye(self.dim)) | ||
|
||
def get_simulator(self) -> Callable: | ||
return partial( | ||
diagonal_linear_gaussian, | ||
std=self.simulator_scale, | ||
) |
Oops, something went wrong.