Skip to content

Commit

Permalink
Add search and sample region in line search acceptance step
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 19, 2024
1 parent 1cacab1 commit ee26d56
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
20 changes: 14 additions & 6 deletions src/tranquilo/acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def accept_classic_line_search(
state,
history,
*,
search_radius_factor,
speculative_sampling_radius_factor,
wrapped_criterion,
min_improvement,
Expand Down Expand Up @@ -147,7 +148,8 @@ def accept_classic_line_search(
if n_unallocated_evals > 0:
speculative_xs = _generate_speculative_sample(
new_center=candidate_x,
radius_factor=speculative_sampling_radius_factor,
sample_radius_factor=speculative_sampling_radius_factor,
search_radius_factor=search_radius_factor,
trustregion=state.trustregion,
sample_points=sample_points,
n_points=n_unallocated_evals,
Expand Down Expand Up @@ -437,7 +439,8 @@ def _generate_speculative_sample(
n_points,
history,
line_search_xs,
radius_factor,
search_radius_factor,
sample_radius_factor,
rng,
):
"""Generative a speculative sample.
Expand All @@ -448,16 +451,21 @@ def _generate_speculative_sample(
sample_points (callable): Function to sample points.
n_points (int): Number of points to sample.
history (History): Tranquilo history.
radius_factor (float): Factor to multiply the trust region radius by to get the
radius of the region from which to draw the speculative sample.
search_radius_factor (float): Factor to multiply the trust region radius by to
get the radius of the search region.
sample_radius_factor (float): Factor to multiply the trust region radius by to
get the radius of the region from which to draw the speculative sample.
rng (np.random.Generator): Random number generator.
Returns:
np.ndarray: Speculative sample.
"""
search_region = trustregion._replace(
center=new_center, radius=radius_factor * trustregion.radius
center=new_center, radius=search_radius_factor * trustregion.radius
)
sample_region = trustregion._replace(
center=new_center, radius=sample_radius_factor * trustregion.radius
)

old_indices = history.get_x_indices_in_region(search_region)
Expand All @@ -470,7 +478,7 @@ def _generate_speculative_sample(
model_xs = old_xs

new_xs = sample_points(
search_region,
sample_region,
n_points=n_points,
existing_xs=model_xs,
rng=rng,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def test_generate_speculative_sample():
sample_points=get_sampler("random_hull"),
n_points=3,
history=history,
radius_factor=1.0,
search_radius_factor=5.0,
sample_radius_factor=1.0,
line_search_xs=None,
rng=np.random.default_rng(1234),
)
Expand Down

0 comments on commit ee26d56

Please sign in to comment.