From 6dd4a22ff99f5fe7f2c3fe0a3044daaf0e73c865 Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 5 Dec 2024 13:59:10 +0100 Subject: [PATCH] further improvements --- sbi/diagnostics/sbc.py | 6 +++- sbi/diagnostics/tarp.py | 4 +++ sbi/utils/diagnostics_utils.py | 55 ++++++++++++++++++++++------------ 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py index ba01fb0a8..59db21450 100644 --- a/sbi/diagnostics/sbc.py +++ b/sbi/diagnostics/sbc.py @@ -2,7 +2,7 @@ # under the Apache License Version 2.0, see import warnings -from typing import Callable, Dict, List, Union +from typing import Callable, Dict, List, Optional, Union import torch from scipy.stats import kstest, uniform @@ -26,6 +26,7 @@ def run_sbc( num_workers: int = 1, show_progress_bar: bool = True, use_batched_sampling: bool = True, + batch_size: Optional[int] = None, **kwargs, ): """Run simulation-based calibration (SBC) (parallelized across sbc runs). @@ -49,6 +50,8 @@ def run_sbc( `num_sbc_samples` inferences. show_progress_bar: whether to display a progress over sbc runs. use_batched_sampling: whether to use batched sampling for posterior samples. + batch_size: batch size for batched sampling. Useful for batched sampling with + large batches of xs for avoiding memory overflow. Returns: ranks: ranks of the ground truth parameters under the inferred @@ -89,6 +92,7 @@ def run_sbc( num_workers, show_progress_bar, use_batched_sampling=use_batched_sampling, + batch_size=batch_size, ) # take a random draw from each posterior to get data averaged posterior samples. diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py index 44ff114f3..d9226fb8f 100644 --- a/sbi/diagnostics/tarp.py +++ b/sbi/diagnostics/tarp.py @@ -29,6 +29,7 @@ def run_tarp( num_bins: Optional[int] = 30, z_score_theta: bool = True, use_batched_sampling: bool = True, + batch_size: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: """ Estimates coverage of samples given true values thetas with the TARP method. @@ -56,6 +57,8 @@ def run_tarp( If ``None``, then ``num_sims // 10`` bins are used. z_score_theta : whether to normalize parameters before coverage test. use_batched_sampling: whether to use batched sampling for posterior samples. + batch_size: batch size for batched sampling. Useful for batched sampling with + large batches of xs for avoiding memory overflow. Returns: ecp: Expected coverage probability (``ecp``), see equation 4 of the paper @@ -70,6 +73,7 @@ def run_tarp( num_workers, show_progress_bar=show_progress_bar, use_batched_sampling=use_batched_sampling, + batch_size=batch_size, ) assert posterior_samples.shape == ( num_posterior_samples, diff --git a/sbi/utils/diagnostics_utils.py b/sbi/utils/diagnostics_utils.py index c783316fd..b3ad038c1 100644 --- a/sbi/utils/diagnostics_utils.py +++ b/sbi/utils/diagnostics_utils.py @@ -1,4 +1,5 @@ import warnings +from typing import Optional import torch from joblib import Parallel, delayed @@ -18,6 +19,7 @@ def get_posterior_samples_on_batch( num_workers: int = 1, show_progress_bar: bool = False, use_batched_sampling: bool = True, + batch_size: Optional[int] = None, ) -> Tensor: """Get posterior samples for a batch of xs. @@ -28,22 +30,37 @@ def get_posterior_samples_on_batch( num_workers: number of workers to use for parallelization. show_progress_bars: whether to show progress bars. use_batched_sampling: whether to use batched sampling if possible. - + batch_size: batch size for batched sampling. Useful for batched sampling with + large batches of xs for avoiding memory overflow. Returns: posterior_samples: of shape (num_samples, batch_size, dim_parameters). """ - batch_size = len(xs) + num_xs = len(xs) + if batch_size is None: + batch_size = num_xs - # Try using batched sampling when implemented. - try: - # has shape (num_samples, batch_size, dim_parameters) - if use_batched_sampling: - posterior_samples = posterior.sample_batched( - sample_shape, x=xs, show_progress_bars=show_progress_bar + if use_batched_sampling: + try: + # distribute the batch of xs into smaller batches + batched_xs = xs.split(batch_size) + posterior_samples = torch.cat( + [ # has shape (num_samples, num_xs, dim_parameters) + posterior.sample_batched( + sample_shape, x=xs_batch, show_progress_bars=show_progress_bar + ) + for xs_batch in batched_xs + ], + dim=1, ) - else: - raise NotImplementedError - except (NotImplementedError, AssertionError): + except (NotImplementedError, AssertionError): + warnings.warn( + "Batched sampling not implemented for this posterior. " + "Falling back to non-batched sampling.", + stacklevel=2, + ) + use_batched_sampling = False + + if not use_batched_sampling: # We need a function with extra training step for new x for VIPosterior. def sample_fun( posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0 @@ -57,13 +74,13 @@ def sample_fun( if isinstance(posterior, (VIPosterior, MCMCPosterior)): warnings.warn( "Using non-batched sampling. Depending on the number of different xs " - f"( {batch_size}) and the number of parallel workers {num_workers}, " - "this might be slow.", + f"( {num_xs}) and the number of parallel workers {num_workers}, " + "this might take a lot of time.", stacklevel=2, ) # Run in parallel with progress bar. - seeds = torch.randint(0, 2**32, (batch_size,)) + seeds = torch.randint(0, 2**32, (num_xs,)) outputs = list( tqdm( Parallel(return_as="generator", n_jobs=num_workers)( @@ -72,7 +89,7 @@ def sample_fun( ), disable=not show_progress_bar, total=len(xs), - desc=f"Sampling {batch_size} times {sample_shape} posterior samples.", + desc=f"Sampling {num_xs} times {sample_shape} posterior samples.", ) ) # (batch_size, num_samples, dim_parameters) # Transpose to shape convention: (sample_shape, batch_size, dim_parameters) @@ -81,8 +98,8 @@ def sample_fun( ).permute(1, 0, 2) assert posterior_samples.shape[:2] == sample_shape + ( - batch_size, - ), f"""Expected batched posterior samples of shape { - sample_shape + (batch_size,) - } got {posterior_samples.shape[:2]}.""" + num_xs, + ), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)} got { + posterior_samples.shape[:2] + }.""" return posterior_samples