Commit
extended to something reasonable
manuelgloeckler committed Dec 19, 2024
1 parent a4f060d commit 6fa9b96
Showing 39 changed files with 382 additions and 44 deletions.
50 changes: 50 additions & 0 deletions test_example.py
@@ -0,0 +1,50 @@
from datetime import datetime
import pytest

import shutil


def pytest_addoption(parser):
parser.addoption(
"--print-harvest",
action="store_true",
default=False,
help="Print the harvest results at the end of the test session",
)


def pytest_terminal_summary(terminalreporter, exitstatus, config):
# Only print the harvest results if the --print-harvest flag is used
if config.getoption("--print-harvest"):
# Dynamically center the summary title in the terminal width
terminal_width = shutil.get_terminal_size().columns
summary_text = " short test summary info "
centered_line = summary_text.center(terminal_width, '=')
terminalreporter.write_line(centered_line)


@pytest.mark.parametrize('p', ['world', 'self'], ids=str)
def test_foo(p, results_bag):
"""
A dummy test, parametrized so that it is executed twice
"""

# Let's store some things in the results bag
results_bag.nb_letters = len(p)
results_bag.current_time = datetime.now().isoformat()


def test_synthesis(fixture_store):
"""
In this test we inspect the contents of the fixture store so far, and
check that the 'results_bag' entry contains a dict <test_id>: <results_bag>
"""
# fetch the 'results_bag' entry from the fixture store
results = fixture_store["results_bag"]

# print what is available for the 'results_bag' entry
print("\n--- Harvested Test Results ---")
for k, v in results.items():
print(k)
for kk, vv in v.items():
print(kk, vv)
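
For orientation, pytest-harvest keys the 'results_bag' store by test node id, so test_synthesis above iterates over one entry per executed test. A hypothetical illustration of that layout for the two parametrized runs of test_foo (node ids and timestamps are made up, not produced by this commit):

# Hypothetical illustration only: the fixture_store layout walked by test_synthesis.
fixture_store = {
    "results_bag": {
        "test_example.py::test_foo[world]": {
            "nb_letters": 5,
            "current_time": "2024-12-19T12:00:00",
        },
        "test_example.py::test_foo[self]": {
            "nb_letters": 4,
            "current_time": "2024-12-19T12:00:01",
        },
    }
}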
146 changes: 108 additions & 38 deletions tests/bm_test.py
@@ -1,31 +1,27 @@
import pytest
import torch

from sbi.inference import NPE, NRE
from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.utils.metrics import c2st

from .mini_sbibm import get_task

# There should probably be some user control over this
SEED = 0
TASKS = ["two_moons", "linear_mvg_2d", "gaussian_linear", "slcp"]
NUM_SIMULATIONS = 2000
EVALUATION_POINTS = 4  # Currently only 3 observations tested, for speed

@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', ['two_moons'], ids=str)
@pytest.mark.parametrize('density_estimator', ["maf", "nsf"], ids=str)
def test_benchmark_npe_methods(
task_name, density_estimator, results_bag, method=None, num_simulations=1000, seed=0
):
torch.manual_seed(seed)
task = get_task(task_name)
thetas, xs = task.get_data(num_simulations)
assert thetas.shape[0] == num_simulations
assert xs.shape[0] == num_simulations
TRAIN_KWARGS = {
# "training_batch_size": 200, # To speed up training
}

inference = NPE(density_estimator=density_estimator)
_ = inference.append_simulations(thetas, xs).train()
# Amortized benchmarking

posterior = inference.build_posterior()

def standard_eval_c2st_loop(posterior, task):
metrics = []
for i in range(1, 2): # Currently only one observation tested for speed
for i in range(1, EVALUATION_POINTS):
x_o = task.get_observation(i)
posterior_samples = task.get_reference_posterior_samples(i)
approx_posterior_samples = posterior.sample((1000,), x=x_o)
@@ -37,43 +33,117 @@ def test_benchmark_npe_methods(
mean_c2st = sum(metrics) / len(metrics)
# Convert to float rounded to 3 decimal places
mean_c2st = float(f"{mean_c2st:.3f}")
return mean_c2st


DENSITY_estimators = ["mdn", "made", "maf", "nsf", "maf_rqs"] # "Kinda exhaustive"
DENSITY_estimators = ["maf", "nsf"] # Fast


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
@pytest.mark.parametrize('density_estimator', DENSITY_estimators, ids=str)
def test_benchmark_npe_methods(task_name, density_estimator, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

print(thetas.shape, xs.shape)

inference = NPE(prior, density_estimator=density_estimator)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

# Cache results
results_bag.metric = mean_c2st
results_bag.num_simulations = num_simulations
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NPE_" + density_estimator


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', ['two_moons'], ids=str)
def test_benchmark_nre_methods(task_name, results_bag, num_simulations=1000, seed=0):
torch.manual_seed(seed)
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_nre_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(num_simulations)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()
assert thetas.shape[0] == num_simulations
assert xs.shape[0] == num_simulations

inference = NRE(prior)
_ = inference.append_simulations(thetas, xs).train()
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

metrics = []
for i in range(1, 2):
x_o = task.get_observation(i)
posterior_samples = task.get_reference_posterior_samples(i)
approx_posterior_samples = posterior.sample((1000,), x=x_o)
if isinstance(approx_posterior_samples, tuple):
approx_posterior_samples = approx_posterior_samples[0]
c2st_val = c2st(posterior_samples[:1000], approx_posterior_samples)
metrics.append(c2st_val)

mean_c2st = sum(metrics) / len(metrics)
# Convert to float rounded to 3 decimal places
mean_c2st = float(f"{mean_c2st:.3f}")
mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = num_simulations
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NRE"


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_nle_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

inference = NLE(prior)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NLE"


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_fmpe_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

inference = FMPE(prior)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "FMPE"


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_npse_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

inference = NPSE(prior)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NPSE"
53 changes: 51 additions & 2 deletions tests/conftest.py
@@ -1,6 +1,10 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import pickle
import shutil
from logging import warning
from pathlib import Path
from shutil import rmtree

import pytest
import torch
@@ -79,7 +83,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
terminalreporter.write_line(colored_line)

if harvested_fixture_data is not None:
terminalreporter.write_line("Harvested Fixture Data:")
terminalreporter.write_line("Amortized inference:")

results = harvested_fixture_data["results_bag"]

@@ -131,7 +135,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
val = data.get((m, t), "N/A")
# Convert metric to string with formatting if needed
# e.g. format(val, ".3f") if val is a float
val_str = str(val)
val_str = format(val, ".3f")
row += val_str.center(task_col_widths[t] + 2)
terminalreporter.write_line(row)

@@ -149,3 +153,48 @@ def mcmc_params_accurate() -> dict:
def mcmc_params_fast() -> dict:
"""Fixture for MCMC parameters for fast tests."""
return dict(num_chains=1, thin=1, warmup_steps=1)


# Pytest harvest xdist support - not sure if we need it (for me xdist is always slower).


# Define the folder in which each worker's temporary results will be stored
RESULTS_PATH = Path('./.xdist_results/')
RESULTS_PATH.mkdir(exist_ok=True)


def pytest_harvest_xdist_init():
# reset the recipient folder
if RESULTS_PATH.exists():
rmtree(RESULTS_PATH)
RESULTS_PATH.mkdir(exist_ok=False)
return True


def pytest_harvest_xdist_worker_dump(worker_id, session_items, fixture_store):
# persist session_items and fixture_store in the file system
with open(RESULTS_PATH / ('%s.pkl' % worker_id), 'wb') as f:
try:
pickle.dump((session_items, fixture_store), f)
except Exception as e:
warning(
"Error while pickling worker %s's harvested results: " "[%s] %s",
(worker_id, e.__class__, e),
)
return True


def pytest_harvest_xdist_load():
# restore the saved objects from the file system
workers_saved_material = dict()
for pkl_file in RESULTS_PATH.glob('*.pkl'):
wid = pkl_file.stem
with pkl_file.open('rb') as f:
workers_saved_material[wid] = pickle.load(f)
return workers_saved_material


def pytest_harvest_xdist_cleanup():
# delete all temporary pickle files
rmtree(RESULTS_PATH)
return True
9 changes: 9 additions & 0 deletions tests/mini_sbibm/__init__.py
@@ -1,8 +1,17 @@
from .gaussian_linear import GaussianLinear
from .linear_mvg import LinearMVG2d
from .slcp import Slcp
from .two_moons import TwoMoons


def get_task(name: str):
if name == "two_moons":
return TwoMoons()
elif name == "linear_mvg_2d":
return LinearMVG2d()
elif name == "gaussian_linear":
return GaussianLinear()
elif name == "slcp":
return Slcp()
else:
raise ValueError(f"Unknown task {name}")
Binary file added tests/mini_sbibm/files/slcp/samples_1.pt
Binary file added tests/mini_sbibm/files/slcp/samples_10.pt
Binary file added tests/mini_sbibm/files/slcp/samples_2.pt
Binary file added tests/mini_sbibm/files/slcp/samples_3.pt
Binary file added tests/mini_sbibm/files/slcp/samples_4.pt
Binary file added tests/mini_sbibm/files/slcp/samples_5.pt
Binary file added tests/mini_sbibm/files/slcp/samples_6.pt
Binary file added tests/mini_sbibm/files/slcp/samples_7.pt
Binary file added tests/mini_sbibm/files/slcp/samples_8.pt
Binary file added tests/mini_sbibm/files/slcp/samples_9.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_1.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_10.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_2.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_3.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_4.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_5.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_6.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_7.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_8.pt
Binary file added tests/mini_sbibm/files/slcp/theta_o_9.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_1.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_10.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_2.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_3.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_4.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_5.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_6.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_7.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_8.pt
Binary file added tests/mini_sbibm/files/slcp/x_o_9.pt
55 changes: 55 additions & 0 deletions tests/mini_sbibm/gaussian_linear.py
@@ -0,0 +1,55 @@
from functools import partial
from typing import Callable

import torch
from torch.distributions import Distribution, MultivariateNormal

from sbi.simulators.linear_gaussian import (
diagonal_linear_gaussian,
true_posterior_linear_gaussian_mvn_prior,
)

from .base_task import Task


class GaussianLinear(Task):
def __init__(self):
self.simulator_scale = 0.1
self.dim = 5
super().__init__("gaussian_linear")

def theta_dim(self) -> int:
return self.dim

def x_dim(self) -> int:
return self.dim

def get_reference_posterior_samples(self, idx: int) -> torch.Tensor:
x_o = self.get_observation(idx)
posterior = true_posterior_linear_gaussian_mvn_prior(
x_o,
torch.zeros(self.dim),
self.simulator_scale * torch.eye(self.dim),
torch.zeros(self.dim),
torch.eye(self.dim),
)

return posterior.sample((10_000,))

def get_true_parameters(self, idx: int) -> torch.Tensor:
torch.manual_seed(idx)
return self.get_prior().sample()

def get_observation(self, idx: int) -> torch.Tensor:
theta_o = self.get_true_parameters(idx)
x_o = self.get_simulator()(theta_o[None, :])[0]
return x_o

def get_prior(self) -> Distribution:
return MultivariateNormal(torch.zeros(self.dim), torch.eye(self.dim))

def get_simulator(self) -> Callable:
return partial(
diagonal_linear_gaussian,
std=self.simulator_scale,
)
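
As a hypothetical sanity check (not part of the commit), the reference posterior of GaussianLinear is analytic via true_posterior_linear_gaussian_mvn_prior, so shapes can be verified directly:

# Hypothetical quick check; shapes follow from dim=5 and the 10_000-sample draw above.
task = GaussianLinear()
theta_o = task.get_true_parameters(1)         # seeded draw from the standard normal prior
x_o = task.get_observation(1)                 # one simulator output for theta_o
ref = task.get_reference_posterior_samples(1)
assert theta_o.shape == (task.theta_dim(),)
assert ref.shape == (10_000, task.theta_dim())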