diff --git a/src/tranquilo/acceptance_decision.py b/src/tranquilo/acceptance_decision.py index 534b339..1b9462d 100644 --- a/src/tranquilo/acceptance_decision.py +++ b/src/tranquilo/acceptance_decision.py @@ -95,6 +95,7 @@ def accept_classic_line_search( state, history, *, + search_radius_factor, speculative_sampling_radius_factor, wrapped_criterion, min_improvement, @@ -147,7 +148,8 @@ def accept_classic_line_search( if n_unallocated_evals > 0: speculative_xs = _generate_speculative_sample( new_center=candidate_x, - radius_factor=speculative_sampling_radius_factor, + sample_radius_factor=speculative_sampling_radius_factor, + search_radius_factor=search_radius_factor, trustregion=state.trustregion, sample_points=sample_points, n_points=n_unallocated_evals, @@ -437,7 +439,8 @@ def _generate_speculative_sample( n_points, history, line_search_xs, - radius_factor, + search_radius_factor, + sample_radius_factor, rng, ): """Generative a speculative sample. @@ -448,8 +451,10 @@ def _generate_speculative_sample( sample_points (callable): Function to sample points. n_points (int): Number of points to sample. history (History): Tranquilo history. - radius_factor (float): Factor to multiply the trust region radius by to get the - radius of the region from which to draw the speculative sample. + search_radius_factor (float): Factor to multiply the trust region radius by to + get the radius of the search region. + sample_radius_factor (float): Factor to multiply the trust region radius by to + get the radius of the region from which to draw the speculative sample. rng (np.random.Generator): Random number generator. Returns: @@ -457,7 +462,10 @@ def _generate_speculative_sample( """ search_region = trustregion._replace( - center=new_center, radius=radius_factor * trustregion.radius + center=new_center, radius=search_radius_factor * trustregion.radius + ) + sample_region = trustregion._replace( + center=new_center, radius=sample_radius_factor * trustregion.radius ) old_indices = history.get_x_indices_in_region(search_region) @@ -470,7 +478,7 @@ def _generate_speculative_sample( model_xs = old_xs new_xs = sample_points( - search_region, + sample_region, n_points=n_points, existing_xs=model_xs, rng=rng, diff --git a/tests/test_acceptance_decision.py b/tests/test_acceptance_decision.py index 2be2bab..ba1f59b 100644 --- a/tests/test_acceptance_decision.py +++ b/tests/test_acceptance_decision.py @@ -185,7 +185,8 @@ def test_generate_speculative_sample(): sample_points=get_sampler("random_hull"), n_points=3, history=history, - radius_factor=1.0, + search_radius_factor=5.0, + sample_radius_factor=1.0, line_search_xs=None, rng=np.random.default_rng(1234), )