Skip to content

Commit

Permalink
fix asktell_gen functionality test - including removing wrapper tests…
Browse files Browse the repository at this point in the history
…, since variables/objectives probably wont be passed in. remove exact H-entry test, since the gen does its own internal persis_info
  • Loading branch information
jlnav committed Nov 4, 2024
1 parent 14daf3c commit 114c7a4
Showing 1 changed file with 4 additions and 19 deletions.
23 changes: 4 additions & 19 deletions libensemble/tests/functionality_tests/test_sampling_asktell_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# Import libEnsemble items for this test
from libensemble.alloc_funcs.start_only_persistent import only_persistent_gens as alloc_f
from libensemble.gen_classes.sampling import UniformSample, UniformSampleDicts
from libensemble.gen_funcs.persistent_gen_wrapper import persistent_gen_f as gen_f
from libensemble.libE import libE
from libensemble.tools import add_unique_random_streams, parse_args

Expand Down Expand Up @@ -58,29 +57,16 @@ def sim_f(In):
alloc_specs = {"alloc_f": alloc_f}
exit_criteria = {"gen_max": 201}

for inst in range(4):
for inst in range(2):
persis_info = add_unique_random_streams({}, nworkers + 1, seed=1234)

if inst == 0:
# Using wrapper - pass class
generator = UniformSample
gen_specs["gen_f"] = gen_f
gen_specs["user"]["generator"] = generator

if inst == 1:
# Using wrapper - pass object
gen_specs["gen_f"] = gen_f
generator = UniformSample(variables, objectives, None, persis_info[1], gen_specs, None)
gen_specs["user"]["generator"] = generator
if inst == 2:
# Using asktell runner - pass object
gen_specs.pop("gen_f", None)
generator = UniformSample(variables, objectives, None, persis_info[1], gen_specs, None)
generator = UniformSample(variables, objectives)
gen_specs["generator"] = generator
if inst == 3:
if inst == 1:
# Using asktell runner - pass object - with standardized interface.
gen_specs.pop("gen_f", None)
generator = UniformSampleDicts(variables, objectives, None, persis_info[1], gen_specs, None)
generator = UniformSampleDicts(variables, objectives)
gen_specs["generator"] = generator

H, persis_info, flag = libE(
Expand All @@ -90,4 +76,3 @@ def sim_f(In):
if is_manager:
print(H[["sim_id", "x", "f"]][:10])
assert len(H) >= 201, f"H has length {len(H)}"
assert np.isclose(H["f"][9], 1.96760289)

0 comments on commit 114c7a4

Please sign in to comment.