diff --git a/test_example.py b/test_example.py
new file mode 100644
index 000000000..628324469
--- /dev/null
+++ b/test_example.py
@@ -0,0 +1,50 @@
+import shutil
+from datetime import datetime
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--print-harvest",
+        action="store_true",
+        default=False,
+        help="Print the harvest results at the end of the test session",
+    )
+
+
+def pytest_terminal_summary(terminalreporter, exitstatus, config):
+    # Only print the harvest results if the --print-harvest flag is used
+    if config.getoption("--print-harvest"):
+        # Dynamically center the summary title in the terminal width
+        terminal_width = shutil.get_terminal_size().columns
+        summary_text = " short test summary info "
+        centered_line = summary_text.center(terminal_width, '=')
+        terminalreporter.write_line(centered_line)
+
+
+@pytest.mark.parametrize('p', ['world', 'self'], ids=str)
+def test_foo(p, results_bag):
+    """
+    A dummy test, parametrized so that it is executed twice.
+    """
+
+    # Store some things in the results bag
+    results_bag.nb_letters = len(p)
+    results_bag.current_time = datetime.now().isoformat()
+
+
+def test_synthesis(fixture_store):
+    """
+    Inspect the contents of the fixture store so far, and check that the
+    'results_bag' entry contains a dict:
+    """
+    # Get the harvested 'results_bag' entries from the fixture store
+    results = fixture_store["results_bag"]
+
+    # Print what was stored by each test
+    print("\n--- Harvested Test Results ---")
+    for k, v in results.items():
+        print(k)
+        for kk, vv in v.items():
+            print(kk, vv)
diff --git a/tests/bm_test.py b/tests/bm_test.py
index a23126ffc..a77d16481 100644
--- a/tests/bm_test.py
+++ b/tests/bm_test.py
@@ -1,31 +1,27 @@
 import pytest
 import torch
 
-from sbi.inference import NPE, NRE
+from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
 from sbi.utils.metrics import c2st
 
 from .mini_sbibm import get_task
 
+# There should probably be some user control over these settings
+SEED = 0
+TASKS = ["two_moons", "linear_mvg_2d", "gaussian_linear", "slcp"]
+NUM_SIMULATIONS = 2000
+EVALUATION_POINTS = 4  # Currently only 3 observations tested for speed
 
-@pytest.mark.benchmark
-@pytest.mark.parametrize('task_name', ['two_moons'], ids=str)
-@pytest.mark.parametrize('density_estimator', ["maf", "nsf"], ids=str)
-def test_benchmark_npe_methods(
-    task_name, density_estimator, results_bag, method=None, num_simulations=1000, seed=0
-):
-    torch.manual_seed(seed)
-    task = get_task(task_name)
-    thetas, xs = task.get_data(num_simulations)
-    assert thetas.shape[0] == num_simulations
-    assert xs.shape[0] == num_simulations
+TRAIN_KWARGS = {
+    # "training_batch_size": 200,  # To speed up training
+}
 
-    inference = NPE(density_estimator=density_estimator)
-    _ = inference.append_simulations(thetas, xs).train()
+# Amortized benchmarking
 
-    posterior = inference.build_posterior()
 
+def standard_eval_c2st_loop(posterior, task):
     metrics = []
-    for i in range(1, 2):  # Currently only one observation tested for speed
+    for i in range(1, EVALUATION_POINTS):
         x_o = task.get_observation(i)
         posterior_samples = task.get_reference_posterior_samples(i)
         approx_posterior_samples = posterior.sample((1000,), x=x_o)
@@ -37,43 +33,117 @@ def test_benchmark_npe_methods(
     mean_c2st = sum(metrics) / len(metrics)
     # Convert to float rounded to 3 decimal places
     mean_c2st = float(f"{mean_c2st:.3f}")
+    return mean_c2st
+
+
+DENSITY_ESTIMATORS = ["mdn", "made", "maf", "nsf", "maf_rqs"]  # "Kinda exhaustive"
+DENSITY_ESTIMATORS = ["maf", "nsf"]  # Fast subset; overrides the full list above
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize('task_name', TASKS, ids=str)
+@pytest.mark.parametrize('density_estimator', DENSITY_ESTIMATORS, ids=str)
+def test_benchmark_npe_methods(task_name, density_estimator, results_bag):
+    torch.manual_seed(SEED)
+    task = get_task(task_name)
+    thetas, xs = task.get_data(NUM_SIMULATIONS)
+    prior = task.get_prior()
+
+    assert thetas.shape[0] == xs.shape[0] == NUM_SIMULATIONS  # sanity check
+
+    inference = NPE(prior, density_estimator=density_estimator)
+    _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)
+    posterior = inference.build_posterior()
+
+    mean_c2st = standard_eval_c2st_loop(posterior, task)
+
+    # Cache results
     results_bag.metric = mean_c2st
-    results_bag.num_simulations = num_simulations
+    results_bag.num_simulations = NUM_SIMULATIONS
     results_bag.task_name = task_name
     results_bag.method = "NPE_" + density_estimator
 
 
 @pytest.mark.benchmark
-@pytest.mark.parametrize('task_name', ['two_moons'], ids=str)
-def test_benchmark_nre_methods(task_name, results_bag, num_simulations=1000, seed=0):
-    torch.manual_seed(seed)
+@pytest.mark.parametrize('task_name', TASKS, ids=str)
+def test_benchmark_nre_methods(task_name, results_bag):
+    torch.manual_seed(SEED)
     task = get_task(task_name)
-    thetas, xs = task.get_data(num_simulations)
+    thetas, xs = task.get_data(NUM_SIMULATIONS)
     prior = task.get_prior()
-    assert thetas.shape[0] == num_simulations
-    assert xs.shape[0] == num_simulations
 
     inference = NRE(prior)
-    _ = inference.append_simulations(thetas, xs).train()
+    _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)
 
     posterior = inference.build_posterior()
 
-    metrics = []
-    for i in range(1, 2):
-        x_o = task.get_observation(i)
-        posterior_samples = task.get_reference_posterior_samples(i)
-        approx_posterior_samples = posterior.sample((1000,), x=x_o)
-        if isinstance(approx_posterior_samples, tuple):
-            approx_posterior_samples = approx_posterior_samples[0]
-        c2st_val = c2st(posterior_samples[:1000], approx_posterior_samples)
-        metrics.append(c2st_val)
-
-    mean_c2st = sum(metrics) / len(metrics)
-    # Convert to float rounded to 3 decimal places
-    mean_c2st = float(f"{mean_c2st:.3f}")
+    mean_c2st = standard_eval_c2st_loop(posterior, task)
 
     results_bag.metric = mean_c2st
-    results_bag.num_simulations = num_simulations
+    results_bag.num_simulations = NUM_SIMULATIONS
     results_bag.task_name = task_name
     results_bag.method = "NRE"
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize('task_name', TASKS, ids=str)
+def test_benchmark_nle_methods(task_name, results_bag):
+    torch.manual_seed(SEED)
+    task = get_task(task_name)
+    thetas, xs = task.get_data(NUM_SIMULATIONS)
+    prior = task.get_prior()
+
+    inference = NLE(prior)
+    _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)
+
+    posterior = inference.build_posterior()
+
+    mean_c2st = standard_eval_c2st_loop(posterior, task)
+
+    results_bag.metric = mean_c2st
+    results_bag.num_simulations = NUM_SIMULATIONS
+    results_bag.task_name = task_name
+    results_bag.method = "NLE"
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize('task_name', TASKS, ids=str)
+def test_benchmark_fmpe_methods(task_name, results_bag):
+    torch.manual_seed(SEED)
+    task = get_task(task_name)
+    thetas, xs = task.get_data(NUM_SIMULATIONS)
+    prior = task.get_prior()
+
+    inference = FMPE(prior)
+    _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)
+
+    posterior = inference.build_posterior()
+
+    mean_c2st = standard_eval_c2st_loop(posterior, task)
+
+    results_bag.metric = mean_c2st
+    results_bag.num_simulations = NUM_SIMULATIONS
+    results_bag.task_name = task_name
+    results_bag.method = "FMPE"
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize('task_name', TASKS, ids=str)
+def test_benchmark_npse_methods(task_name, results_bag):
+    torch.manual_seed(SEED)
+    task = get_task(task_name)
+    thetas, xs = task.get_data(NUM_SIMULATIONS)
+    prior = task.get_prior()
+
+    inference = NPSE(prior)
+    _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)
+
+    posterior = inference.build_posterior()
+
+    mean_c2st = standard_eval_c2st_loop(posterior, task)
+
+    results_bag.metric = mean_c2st
+    results_bag.num_simulations = NUM_SIMULATIONS
+    results_bag.task_name = task_name
+    results_bag.method = "NPSE"
diff --git a/tests/conftest.py b/tests/conftest.py
index 6f01cd9a6..2ec42c3a3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,10 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see
+import pickle
 import shutil
+from logging import warning
+from pathlib import Path
+from shutil import rmtree
 
 import pytest
 import torch
@@ -79,7 +83,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
             terminalreporter.write_line(colored_line)
 
     if harvested_fixture_data is not None:
-        terminalreporter.write_line("Harvested Fixture Data:")
+        terminalreporter.write_line("Amortized inference:")
 
         results = harvested_fixture_data["results_bag"]
 
@@ -131,7 +135,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
             val = data.get((m, t), "N/A")
             # Convert metric to string with formatting if needed
             # e.g. format(val, ".3f") if val is a float
-            val_str = str(val)
+            val_str = format(val, ".3f") if isinstance(val, float) else str(val)
             row += val_str.center(task_col_widths[t] + 2)
 
         terminalreporter.write_line(row)
@@ -149,3 +153,48 @@ def mcmc_params_accurate() -> dict:
 def mcmc_params_fast() -> dict:
     """Fixture for MCMC parameters for fast tests."""
     return dict(num_chains=1, thin=1, warmup_steps=1)
+
+
+# pytest-harvest xdist support. Possibly unnecessary: in practice, xdist has been slower here.
+
+
+# Define the folder in which temporary worker results will be stored
+RESULTS_PATH = Path('./.xdist_results/')
+RESULTS_PATH.mkdir(exist_ok=True)
+
+
+def pytest_harvest_xdist_init():
+    # Reset the recipient folder
+    if RESULTS_PATH.exists():
+        rmtree(RESULTS_PATH)
+    RESULTS_PATH.mkdir(exist_ok=False)
+    return True
+
+
+def pytest_harvest_xdist_worker_dump(worker_id, session_items, fixture_store):
+    # Persist session_items and fixture_store in the file system
+    with open(RESULTS_PATH / ('%s.pkl' % worker_id), 'wb') as f:
+        try:
+            pickle.dump((session_items, fixture_store), f)
+        except Exception as e:
+            warning(
+                "Error while pickling worker %s's harvested results: [%s] %s",
+                worker_id, e.__class__, e,
+            )
+    return True
+
+
+def pytest_harvest_xdist_load():
+    # Restore the saved objects from the file system
+    workers_saved_material = dict()
+    for pkl_file in RESULTS_PATH.glob('*.pkl'):
+        wid = pkl_file.stem
+        with pkl_file.open('rb') as f:
+            workers_saved_material[wid] = pickle.load(f)
+    return workers_saved_material
+
+
+def pytest_harvest_xdist_cleanup():
+    # Delete all temporary pickle files
+    rmtree(RESULTS_PATH)
+    return True
diff --git a/tests/mini_sbibm/__init__.py b/tests/mini_sbibm/__init__.py
index 4b9094cda..e6565c984 100644
--- a/tests/mini_sbibm/__init__.py
+++ b/tests/mini_sbibm/__init__.py
@@ -1,8 +1,17 @@
+from .gaussian_linear import GaussianLinear
+from .linear_mvg import LinearMVG2d
+from .slcp import Slcp
 from .two_moons import TwoMoons
 
 
 def get_task(name: str):
     if name == "two_moons":
         return TwoMoons()
+    elif name == "linear_mvg_2d":
+        return LinearMVG2d()
+    elif name == "gaussian_linear":
+        return GaussianLinear()
+    elif name == "slcp":
+        return Slcp()
     else:
         raise ValueError(f"Unknown task {name}")
diff --git a/tests/mini_sbibm/files/slcp/samples_1.pt b/tests/mini_sbibm/files/slcp/samples_1.pt
new file mode 100644
index 000000000..d0a1bd771
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_1.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_10.pt b/tests/mini_sbibm/files/slcp/samples_10.pt
new file mode 100644
index 000000000..0de7efad6
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_10.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_2.pt b/tests/mini_sbibm/files/slcp/samples_2.pt
new file mode 100644
index 000000000..f642597d3
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_2.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_3.pt b/tests/mini_sbibm/files/slcp/samples_3.pt
new file mode 100644
index 000000000..640bcd1dc
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_3.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_4.pt b/tests/mini_sbibm/files/slcp/samples_4.pt
new file mode 100644
index 000000000..1397dc02b
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_4.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_5.pt b/tests/mini_sbibm/files/slcp/samples_5.pt
new file mode 100644
index 000000000..f2e8c35f1
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_5.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_6.pt b/tests/mini_sbibm/files/slcp/samples_6.pt
new file mode 100644
index 000000000..091bb8143
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_6.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_7.pt b/tests/mini_sbibm/files/slcp/samples_7.pt
new file mode 100644
index 000000000..edcbe6596
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_7.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_8.pt b/tests/mini_sbibm/files/slcp/samples_8.pt
new file mode 100644
index 000000000..ec1cd0392
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_8.pt differ
diff --git a/tests/mini_sbibm/files/slcp/samples_9.pt b/tests/mini_sbibm/files/slcp/samples_9.pt
new file mode 100644
index 000000000..38b14665e
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/samples_9.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_1.pt b/tests/mini_sbibm/files/slcp/theta_o_1.pt
new file mode 100644
index 000000000..ab4dc7bae
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_1.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_10.pt b/tests/mini_sbibm/files/slcp/theta_o_10.pt
new file mode 100644
index 000000000..4056ae9de
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_10.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_2.pt b/tests/mini_sbibm/files/slcp/theta_o_2.pt
new file mode 100644
index 000000000..529388c43
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_2.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_3.pt b/tests/mini_sbibm/files/slcp/theta_o_3.pt
new file mode 100644
index 000000000..97e333ced
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_3.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_4.pt b/tests/mini_sbibm/files/slcp/theta_o_4.pt
new file mode 100644
index 000000000..27f22b885
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_4.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_5.pt b/tests/mini_sbibm/files/slcp/theta_o_5.pt
new file mode 100644
index 000000000..64c3c77d4
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_5.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_6.pt b/tests/mini_sbibm/files/slcp/theta_o_6.pt
new file mode 100644
index 000000000..607ced052
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_6.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_7.pt b/tests/mini_sbibm/files/slcp/theta_o_7.pt
new file mode 100644
index 000000000..7da5b387b
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_7.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_8.pt b/tests/mini_sbibm/files/slcp/theta_o_8.pt
new file mode 100644
index 000000000..d4ccf87d0
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_8.pt differ
diff --git a/tests/mini_sbibm/files/slcp/theta_o_9.pt b/tests/mini_sbibm/files/slcp/theta_o_9.pt
new file mode 100644
index 000000000..38b6ecd7b
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/theta_o_9.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_1.pt b/tests/mini_sbibm/files/slcp/x_o_1.pt
new file mode 100644
index 000000000..f806b3ff8
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_1.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_10.pt b/tests/mini_sbibm/files/slcp/x_o_10.pt
new file mode 100644
index 000000000..ddfe6c805
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_10.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_2.pt b/tests/mini_sbibm/files/slcp/x_o_2.pt
new file mode 100644
index 000000000..7c7f92274
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_2.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_3.pt b/tests/mini_sbibm/files/slcp/x_o_3.pt
new file mode 100644
index 000000000..de8576ce5
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_3.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_4.pt b/tests/mini_sbibm/files/slcp/x_o_4.pt
new file mode 100644
index 000000000..a60abaf24
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_4.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_5.pt b/tests/mini_sbibm/files/slcp/x_o_5.pt
new file mode 100644
index 000000000..a8f169e89
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_5.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_6.pt b/tests/mini_sbibm/files/slcp/x_o_6.pt
new file mode 100644
index 000000000..56d43231f
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_6.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_7.pt b/tests/mini_sbibm/files/slcp/x_o_7.pt
new file mode 100644
index 000000000..d88572ab9
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_7.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_8.pt b/tests/mini_sbibm/files/slcp/x_o_8.pt
new file mode 100644
index 000000000..35e6f5d21
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_8.pt differ
diff --git a/tests/mini_sbibm/files/slcp/x_o_9.pt b/tests/mini_sbibm/files/slcp/x_o_9.pt
new file mode 100644
index 000000000..89211253f
Binary files /dev/null and b/tests/mini_sbibm/files/slcp/x_o_9.pt differ
diff --git a/tests/mini_sbibm/gaussian_linear.py b/tests/mini_sbibm/gaussian_linear.py
new file mode 100644
index 000000000..9dd31e7e6
--- /dev/null
+++ b/tests/mini_sbibm/gaussian_linear.py
@@ -0,0 +1,55 @@
+from functools import partial
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution, MultivariateNormal
+
+from sbi.simulators.linear_gaussian import (
+    diagonal_linear_gaussian,
+    true_posterior_linear_gaussian_mvn_prior,
+)
+
+from .base_task import Task
+
+
+class GaussianLinear(Task):
+    def __init__(self):
+        self.simulator_scale = 0.1
+        self.dim = 5
+        super().__init__("gaussian_linear")
+
+    def theta_dim(self) -> int:
+        return self.dim
+
+    def x_dim(self) -> int:
+        return self.dim
+
+    def get_reference_posterior_samples(self, idx: int) -> torch.Tensor:
+        x_o = self.get_observation(idx)
+        posterior = true_posterior_linear_gaussian_mvn_prior(
+            x_o,
+            torch.zeros(self.dim),
+            self.simulator_scale**2 * torch.eye(self.dim),  # likelihood cov is std**2 * I
+            torch.zeros(self.dim),
+            torch.eye(self.dim),
+        )
+
+        return posterior.sample((10_000,))
+
+    def get_true_parameters(self, idx: int) -> torch.Tensor:
+        torch.manual_seed(idx)
+        return self.get_prior().sample()
+
+    def get_observation(self, idx: int) -> torch.Tensor:
+        theta_o = self.get_true_parameters(idx)
+        x_o = self.get_simulator()(theta_o[None, :])[0]
+        return x_o
+
+    def get_prior(self) -> Distribution:
+        return MultivariateNormal(torch.zeros(self.dim), torch.eye(self.dim))
+
+    def get_simulator(self) -> Callable:
+        return partial(
+            diagonal_linear_gaussian,
+            std=self.simulator_scale,
+        )
diff --git a/tests/mini_sbibm/linear_mvg.py b/tests/mini_sbibm/linear_mvg.py
new file mode 100644
index 000000000..f2e842dfa
--- /dev/null
+++ b/tests/mini_sbibm/linear_mvg.py
@@ -0,0 +1,56 @@
+from functools import partial
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution, MultivariateNormal
+
+from sbi.simulators.linear_gaussian import (
+    linear_gaussian,
+    true_posterior_linear_gaussian_mvn_prior,
+)
+
+from .base_task import Task
+
+
+class LinearMVG2d(Task):
+    def __init__(self):
+        self.likelihood_shift = torch.tensor([-1.0, 1.0])
+        self.likelihood_cov = torch.tensor([[0.6, 0.5], [0.5, 0.6]])
+        super().__init__("linear_mvg_2d")
+
+    def theta_dim(self) -> int:
+        return 2
+
+    def x_dim(self) -> int:
+        return 2
+
+    def get_reference_posterior_samples(self, idx: int) -> torch.Tensor:
+        x_o = self.get_observation(idx)
+        posterior = true_posterior_linear_gaussian_mvn_prior(
+            x_o,
+            self.likelihood_shift,
+            self.likelihood_cov,
+            torch.zeros(2),
+            torch.eye(2),
+        )
+
+        return posterior.sample((10_000,))
+
+    def get_true_parameters(self, idx: int) -> torch.Tensor:
+        torch.manual_seed(idx)
+        return self.get_prior().sample()
+
+    def get_observation(self, idx: int) -> torch.Tensor:
+        theta_o = self.get_true_parameters(idx)
+        x_o = self.get_simulator()(theta_o[None, :])[0]
+        return x_o
+
+    def get_prior(self) -> Distribution:
+        return MultivariateNormal(torch.zeros(2), torch.eye(2))
+
+    def get_simulator(self) -> Callable:
+        return partial(
+            linear_gaussian,
+            likelihood_shift=self.likelihood_shift,
+            likelihood_cov=self.likelihood_cov,
+        )
diff --git a/tests/mini_sbibm/slcp.py b/tests/mini_sbibm/slcp.py
new file mode 100644
index 000000000..b13130877
--- /dev/null
+++ b/tests/mini_sbibm/slcp.py
@@ -0,0 +1,52 @@
+from typing import Callable
+
+import torch
+from torch.distributions import Distribution, Independent, MultivariateNormal, Uniform
+
+from .base_task import Task
+
+
+def simulator(theta, num_data=4):
+    num_samples = theta.shape[0]
+
+    m = torch.stack((theta[:, [0]].squeeze(), theta[:, [1]].squeeze())).T
+    if m.dim() == 1:
+        m.unsqueeze_(0)
+
+    s1 = theta[:, [2]].squeeze() ** 2
+    s2 = theta[:, [3]].squeeze() ** 2
+    rho = torch.tanh(theta[:, [4]]).squeeze()
+
+    S = torch.empty((num_samples, 2, 2))
+    S[:, 0, 0] = s1**2
+    S[:, 0, 1] = rho * s1 * s2
+    S[:, 1, 0] = rho * s1 * s2
+    S[:, 1, 1] = s2**2
+
+    # Add eps to the diagonal to keep the covariance positive definite
+    eps = 1e-6
+    S[:, 0, 0] += eps
+    S[:, 1, 1] += eps
+
+    data_dist = MultivariateNormal(m, S)
+    xs = data_dist.sample((num_data,))
+    xs = xs.permute(1, 0, 2)
+
+    return xs.reshape(num_samples, num_data * 2)
+
+
+class Slcp(Task):
+    def __init__(self):
+        super().__init__("slcp")
+
+    def theta_dim(self) -> int:
+        return 5
+
+    def x_dim(self) -> int:
+        return 8
+
+    def get_prior(self) -> Distribution:
+        return Independent(Uniform(-3 * torch.ones(5), 3 * torch.ones(5)), 1)
+
+    def get_simulator(self) -> Callable:
+        return simulator
diff --git a/tests/mini_sbibm/two_moons.py b/tests/mini_sbibm/two_moons.py
index 68d7d3ae6..dc50f9cf2 100644
--- a/tests/mini_sbibm/two_moons.py
+++ b/tests/mini_sbibm/two_moons.py
@@ -1,5 +1,4 @@
 import math
-import os
 from typing import Callable
 
 import torch
@@ -7,8 +6,6 @@
 
 from .base_task import Task
 
-PATH = os.path.dirname(__file__)
-
 
 def _map_fun(parameters: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
     ang = torch.tensor([-math.pi / 4.0])
diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py
index cd275a7b0..3b330a91c 100644
--- a/tests/user_input_checks_test.py
+++ b/tests/user_input_checks_test.py
@@ -85,7 +85,7 @@ def matrix_simulator(theta):
 
 
 # Set default tensor locally to reach tensors in fixtures.
-torch.set_default_tensor_type(torch.FloatTensor)
+torch.set_default_dtype(torch.float32)
 
 
 @pytest.mark.parametrize(