diff --git a/src/tranquilo/options.py b/src/tranquilo/options.py index 44c2982..17b9478 100644 --- a/src/tranquilo/options.py +++ b/src/tranquilo/options.py @@ -23,6 +23,10 @@ def get_default_stagnation_options(noisy, batch_size): return out +def get_default_sample_filter(batch_size): + return "drop_excess" if batch_size > 1 else "keep_all" + + def get_default_radius_options(x): return RadiusOptions(initial_radius=0.1 * np.max(np.abs(x))) diff --git a/src/tranquilo/process_arguments.py b/src/tranquilo/process_arguments.py index 2755333..4bb9821 100644 --- a/src/tranquilo/process_arguments.py +++ b/src/tranquilo/process_arguments.py @@ -22,6 +22,7 @@ get_default_sample_size, get_default_search_radius_factor, get_default_stagnation_options, + get_default_sample_filter, update_option_bundle, NoiseAdaptationOptions, ) @@ -73,7 +74,7 @@ def process_arguments( # component names and related options sampler="optimal_hull", sampler_options=None, - sample_filter="keep_all", + sample_filter=None, sample_filter_options=None, model_fitter=None, model_fitter_options=None, @@ -156,6 +157,7 @@ def process_arguments( acceptance_decider = _process_acceptance_decider(acceptance_decider, noisy) # process options that depend on arguments with dependent defaults + sample_filter = _process_sample_filter(sample_filter, batch_size) stagnation_options = update_option_bundle( get_default_stagnation_options(noisy, batch_size=batch_size), stagnation_options ) @@ -274,6 +276,15 @@ def _process_batch_size(batch_size, n_cores): return int(batch_size) +def _process_sample_filter(sample_filter, batch_size): + if sample_filter is None: + out = get_default_sample_filter(batch_size) + else: + out = sample_filter + + return out + + def _process_sample_size(sample_size, model_type, x): if sample_size is None: out = get_default_sample_size(model_type=model_type, x=x)