diff --git a/.codecov.yml b/.codecov.yml index 18ef40801..f99839378 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -5,3 +5,4 @@ ignore: - "libensemble/sim_funcs/executor_hworld.py" - "libensemble/gen_funcs/persistent_ax_multitask.py" - "libensemble/gen_funcs/persistent_gpCAM.py" + - "libensemble/gen_classes/gpCAM.py" diff --git a/.flake8 b/.flake8 index d49bc0d3b..c21368b65 100644 --- a/.flake8 +++ b/.flake8 @@ -40,6 +40,7 @@ per-file-ignores = libensemble/tests/scaling_tests/warpx/run_libensemble_on_warpx.py:E402 examples/calling_scripts/run_libensemble_on_warpx.py:E402 libensemble/tests/regression_tests/test_persistent_aposmm*:E402 + libensemble/tests/regression_tests/test_asktell_aposmm_nlopt.py:E402 libensemble/tests/regression_tests/test_persistent_gp_multitask_ax.py:E402 libensemble/tests/functionality_tests/test_uniform_sampling_then_persistent_localopt_runs.py:E402 libensemble/tests/functionality_tests/test_stats_output.py:E402 diff --git a/.gitignore b/.gitignore index 828a6fff6..c6bd3c0dd 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ dist/ .spyproject/ .hypothesis +.pixi diff --git a/docs/function_guides/ask_tell_generator.rst b/docs/function_guides/ask_tell_generator.rst new file mode 100644 index 000000000..6212b24f5 --- /dev/null +++ b/docs/function_guides/ask_tell_generator.rst @@ -0,0 +1,21 @@ + +Ask/Tell Generators +=================== + +**BETA - SUBJECT TO CHANGE** + +These generators, implementations, methods, and subclasses are in BETA, and +may change in future releases. + +The Generator interface is expected to roughly correspond with CAMPA's standard: +https://github.com/campa-consortium/generator_standard + +libEnsemble is in the process of supporting generator objects that implement the following interface: + +.. automodule:: generators + :members: Generator LibensembleGenerator + :undoc-members: + +.. autoclass:: Generator + :member-order: bysource + :members: diff --git a/docs/function_guides/function_guide_index.rst b/docs/function_guides/function_guide_index.rst index 621bf36d2..0539e24c6 100644 --- a/docs/function_guides/function_guide_index.rst +++ b/docs/function_guides/function_guide_index.rst @@ -13,6 +13,7 @@ These guides describe common development patterns and optional components: :caption: Writing User Functions generator + ask_tell_generator simulator allocator sim_gen_alloc_api diff --git a/libensemble/__init__.py b/libensemble/__init__.py index 605336821..8df3af207 100644 --- a/libensemble/__init__.py +++ b/libensemble/__init__.py @@ -12,3 +12,4 @@ from libensemble import logger from .ensemble import Ensemble +from .generators import Generator diff --git a/libensemble/comms/comms.py b/libensemble/comms/comms.py index 51042c463..d8d892319 100644 --- a/libensemble/comms/comms.py +++ b/libensemble/comms/comms.py @@ -226,6 +226,7 @@ def _qcomm_main(comm, main, *args, **kwargs): if not kwargs.get("user_function"): _result = main(comm, *args, **kwargs) else: + # SH - could we insert comm into libE_info["comm"] here if it exists _result = main(*args) comm.send(CommResult(_result)) except Exception as e: @@ -264,8 +265,8 @@ def __init__(self, main, nworkers, *args, **kwargs): self.inbox = Queue() self.outbox = Queue() super().__init__(self, main, *args, **kwargs) - comm = QComm(self.inbox, self.outbox, nworkers) - self.handle = Process(target=_qcomm_main, args=(comm, main) + args, kwargs=kwargs) + self.comm = QComm(self.inbox, self.outbox, nworkers) + self.handle = Process(target=_qcomm_main, args=(self.comm, main) + args, kwargs=kwargs) def terminate(self, timeout=None): """Terminate the process.""" diff --git a/libensemble/executors/mpi_runner.py b/libensemble/executors/mpi_runner.py index eb002d14b..680f6a086 100644 --- a/libensemble/executors/mpi_runner.py +++ b/libensemble/executors/mpi_runner.py @@ -21,7 +21,7 @@ def get_runner(mpi_runner_type, runner_name=None, platform_info=None): "msmpi": MSMPI_MPIRunner, "custom": MPIRunner, } - mpi_runner = mpi_runners[mpi_runner_type] + mpi_runner = mpi_runners.get(mpi_runner_type, MPIRunner) if runner_name is not None: runner = mpi_runner(run_command=runner_name, platform_info=platform_info) else: diff --git a/libensemble/gen_classes/__init__.py b/libensemble/gen_classes/__init__.py new file mode 100644 index 000000000..f33c2ebc0 --- /dev/null +++ b/libensemble/gen_classes/__init__.py @@ -0,0 +1,2 @@ +from .aposmm import APOSMM # noqa: F401 +from .sampling import UniformSample, UniformSampleDicts # noqa: F401 diff --git a/libensemble/gen_classes/aposmm.py b/libensemble/gen_classes/aposmm.py new file mode 100644 index 000000000..1cb802173 --- /dev/null +++ b/libensemble/gen_classes/aposmm.py @@ -0,0 +1,125 @@ +import copy +from typing import List + +import numpy as np +from numpy import typing as npt + +from libensemble.generators import LibensembleGenThreadInterfacer +from libensemble.message_numbers import EVAL_GEN_TAG, PERSIS_STOP + + +class APOSMM(LibensembleGenThreadInterfacer): + """ + Standalone object-oriented APOSMM generator + """ + + def __init__( + self, + variables: dict, + objectives: dict, + History: npt.NDArray = [], + persis_info: dict = {}, + gen_specs: dict = {}, + libE_info: dict = {}, + **kwargs + ) -> None: + from libensemble.gen_funcs.persistent_aposmm import aposmm + + self.variables = variables + self.objectives = objectives + + gen_specs["gen_f"] = aposmm + + if not gen_specs.get("out"): # gen_specs never especially changes for aposmm even as the problem varies + if not self.variables: + self.n = len(kwargs["lb"]) or len(kwargs["ub"]) + else: + self.n = len(self.variables) + gen_specs["out"] = [ + ("x", float, self.n), + ("x_on_cube", float, self.n), + ("sim_id", int), + ("local_min", bool), + ("local_pt", bool), + ] + gen_specs["persis_in"] = ["x", "f", "local_pt", "sim_id", "sim_ended", "x_on_cube", "local_min"] + super().__init__(variables, objectives, History, persis_info, gen_specs, libE_info, **kwargs) + if not self.persis_info.get("nworkers"): + self.persis_info["nworkers"] = kwargs.get("nworkers", gen_specs["user"]["max_active_runs"]) + self.all_local_minima = [] + self._ask_idx = 0 + self._last_ask = None + self._tell_buf = None + self._n_buffd_results = 0 + self._told_initial_sample = False + + def _slot_in_data(self, results): + """Slot in libE_calc_in and trial data into corresponding array fields. *Initial sample only!!*""" + self._tell_buf[self._n_buffd_results : self._n_buffd_results + len(results)] = results + + def _enough_initial_sample(self): + return ( + self._n_buffd_results >= int(self.gen_specs["user"]["initial_sample_size"]) + ) or self._told_initial_sample + + def _ready_to_ask_genf(self): + """ + We're presumably ready to be asked IF: + - When we're working on the initial sample: + - We have no _last_ask cached + - all points given out have returned AND we've been asked *at least* as many points as we cached + - When we're done with the initial sample: + - we've been asked *at least* as many points as we cached + """ + if not self._told_initial_sample and self._last_ask is not None: + cond = all([i in self._tell_buf["sim_id"] for i in self._last_ask["sim_id"]]) + else: + cond = True + return self._last_ask is None or (cond and (self._ask_idx >= len(self._last_ask))) + + def ask_numpy(self, num_points: int = 0) -> npt.NDArray: + """Request the next set of points to evaluate, as a NumPy array.""" + if self._ready_to_ask_genf(): + self._ask_idx = 0 + self._last_ask = super().ask_numpy(num_points) + + if self._last_ask["local_min"].any(): # filter out local minima rows + min_idxs = self._last_ask["local_min"] + self.all_local_minima.append(self._last_ask[min_idxs]) + self._last_ask = self._last_ask[~min_idxs] + + if num_points > 0: # we've been asked for a selection of the last ask + results = np.copy(self._last_ask[self._ask_idx : self._ask_idx + num_points]) + self._ask_idx += num_points + + else: + results = np.copy(self._last_ask) + self._last_ask = None + + return results + + def tell_numpy(self, results: npt.NDArray, tag: int = EVAL_GEN_TAG) -> None: + if (results is None and tag == PERSIS_STOP) or self._told_initial_sample: + super().tell_numpy(results, tag) + return + + # Initial sample buffering here: + + if self._n_buffd_results == 0: + self._tell_buf = np.zeros(self.gen_specs["user"]["initial_sample_size"], dtype=results.dtype) + self._tell_buf["sim_id"] = -1 + + if not self._enough_initial_sample(): + self._slot_in_data(np.copy(results)) + self._n_buffd_results += len(results) + + if self._enough_initial_sample(): + super().tell_numpy(self._tell_buf, tag) + self._told_initial_sample = True + self._n_buffd_results = 0 + + def ask_updates(self) -> List[npt.NDArray]: + """Request a list of NumPy arrays containing entries that have been identified as minima.""" + minima = copy.deepcopy(self.all_local_minima) + self.all_local_minima = [] + return minima diff --git a/libensemble/gen_classes/gpCAM.py b/libensemble/gen_classes/gpCAM.py new file mode 100644 index 000000000..884832980 --- /dev/null +++ b/libensemble/gen_classes/gpCAM.py @@ -0,0 +1,157 @@ +"""Generator class exposing gpCAM functionality""" + +import time +from typing import List + +import numpy as np +from gpcam import GPOptimizer as GP +from numpy import typing as npt + +# While there are class / func duplicates - re-use functions. +from libensemble.gen_funcs.persistent_gpCAM import ( + _calculate_grid_distances, + _eval_var, + _find_eligible_points, + _generate_mesh, + _read_testpoints, +) +from libensemble.generators import LibensembleGenerator + +__all__ = [ + "GP_CAM", + "GP_CAM_Covar", +] + + +# Note - batch size is set in wrapper currently - and passed to ask as n_trials. +# To support empty ask(), add batch_size back in here. + + +# Equivalent to function persistent_gpCAM_ask_tell +class GP_CAM(LibensembleGenerator): + """ + This generation function constructs a global surrogate of `f` values. + + It is a batched method that produces a first batch uniformly random from + (lb, ub). On subequent iterations, it calls an optimization method to + produce the next batch of points. This optimization might be too slow + (relative to the simulation evaluation time) for some use cases. + """ + + def _initialize_gpCAM(self, user_specs): + """Extract user params""" + # self.b = user_specs["batch_size"] + self.lb = np.array(user_specs["lb"]) + self.ub = np.array(user_specs["ub"]) + self.n = len(self.lb) # dimension + assert isinstance(self.n, int), "Dimension must be an integer" + assert isinstance(self.lb, np.ndarray), "lb must be a numpy array" + assert isinstance(self.ub, np.ndarray), "ub must be a numpy array" + self.all_x = np.empty((0, self.n)) + self.all_y = np.empty((0, 1)) + np.random.seed(0) + + def __init__(self, H, persis_info, gen_specs, libE_info=None): + self.H = H # Currently not used - could be used for an H0 + self.persis_info = persis_info + self.gen_specs = gen_specs + self.libE_info = libE_info + + self.U = self.gen_specs["user"] + self._initialize_gpCAM(self.U) + self.rng = self.persis_info["rand_stream"] + + self.my_gp = None + self.noise = 1e-8 # 1e-12 + self.ask_max_iter = self.gen_specs["user"].get("ask_max_iter") or 10 + + def ask_numpy(self, n_trials: int) -> npt.NDArray: + if self.all_x.shape[0] == 0: + self.x_new = self.rng.uniform(self.lb, self.ub, (n_trials, self.n)) + else: + start = time.time() + self.x_new = self.my_gp.ask( + input_set=np.column_stack((self.lb, self.ub)), + n=n_trials, + pop_size=n_trials, + acquisition_function="total correlation", + max_iter=self.ask_max_iter, # Larger takes longer. gpCAM default is 20. + )["x"] + print(f"Ask time:{time.time() - start}") + H_o = np.zeros(n_trials, dtype=self.gen_specs["out"]) + H_o["x"] = self.x_new + return H_o + + def tell_numpy(self, calc_in: npt.NDArray) -> None: + if calc_in is not None: + if "x" in calc_in.dtype.names: # SH should we require x in? + self.x_new = np.atleast_2d(calc_in["x"]) + self.y_new = np.atleast_2d(calc_in["f"]).T + nan_indices = [i for i, fval in enumerate(self.y_new) if np.isnan(fval[0])] + self.x_new = np.delete(self.x_new, nan_indices, axis=0) + self.y_new = np.delete(self.y_new, nan_indices, axis=0) + + self.all_x = np.vstack((self.all_x, self.x_new)) + self.all_y = np.vstack((self.all_y, self.y_new)) + + noise_var = self.noise * np.ones(len(self.all_y)) + if self.my_gp is None: + self.my_gp = GP(self.all_x, self.all_y.flatten(), noise_variances=noise_var) + else: + self.my_gp.tell(self.all_x, self.all_y.flatten(), noise_variances=noise_var) + self.my_gp.train() + + +class GP_CAM_Covar(GP_CAM): + """ + This generation function constructs a global surrogate of `f` values. + + It is a batched method that produces a first batch uniformly random from + (lb, ub) and on following iterations samples the GP posterior covariance + function to find sample points. + """ + + def __init__(self, H, persis_info, gen_specs, libE_info=None): + super().__init__(H, persis_info, gen_specs, libE_info) + self.test_points = _read_testpoints(self.U) + self.x_for_var = None + self.var_vals = None + if self.U.get("use_grid"): + self.num_points = 10 + self.x_for_var = _generate_mesh(self.lb, self.ub, self.num_points) + self.r_low_init, self.r_high_init = _calculate_grid_distances(self.lb, self.ub, self.num_points) + + def ask_numpy(self, n_trials: int) -> List[dict]: + if self.all_x.shape[0] == 0: + x_new = self.rng.uniform(self.lb, self.ub, (n_trials, self.n)) + else: + if not self.U.get("use_grid"): + x_new = self.x_for_var[np.argsort(self.var_vals)[-n_trials:]] + else: + r_high = self.r_high_init + r_low = self.r_low_init + x_new = [] + r_cand = r_high # Let's start with a large radius and stop when we have batchsize points + + sorted_indices = np.argsort(-self.var_vals) + while len(x_new) < n_trials: + x_new = _find_eligible_points(self.x_for_var, sorted_indices, r_cand, n_trials) + if len(x_new) < n_trials: + r_high = r_cand + r_cand = (r_high + r_low) / 2.0 + + self.x_new = x_new + H_o = np.zeros(n_trials, dtype=self.gen_specs["out"]) + H_o["x"] = self.x_new + return H_o + + def tell_numpy(self, calc_in: npt.NDArray): + if calc_in is not None: + super().tell_numpy(calc_in) + if not self.U.get("use_grid"): + n_trials = len(self.y_new) + self.x_for_var = self.rng.uniform(self.lb, self.ub, (10 * n_trials, self.n)) + + self.var_vals = _eval_var( + self.my_gp, self.all_x, self.all_y, self.x_for_var, self.test_points, self.persis_info + ) diff --git a/libensemble/gen_classes/sampling.py b/libensemble/gen_classes/sampling.py new file mode 100644 index 000000000..35a075e22 --- /dev/null +++ b/libensemble/gen_classes/sampling.py @@ -0,0 +1,78 @@ +"""Generator classes providing points using sampling""" + +import numpy as np + +from libensemble.generators import Generator, LibensembleGenerator +from libensemble.utils.misc import list_dicts_to_np + +__all__ = [ + "UniformSample", + "UniformSampleDicts", +] + + +class SampleBase(LibensembleGenerator): + """Base class for sampling generators""" + + def _get_user_params(self, user_specs): + """Extract user params""" + self.ub = user_specs["ub"] + self.lb = user_specs["lb"] + self.n = len(self.lb) # dimension + assert isinstance(self.n, int), "Dimension must be an integer" + assert isinstance(self.lb, np.ndarray), "lb must be a numpy array" + assert isinstance(self.ub, np.ndarray), "ub must be a numpy array" + + +class UniformSample(SampleBase): + """ + This generator returns ``gen_specs["initial_batch_size"]`` uniformly + sampled points the first time it is called. Afterwards, it returns the + number of points given. This can be used in either a batch or asynchronous + mode by adjusting the allocation function. + """ + + def __init__(self, variables: dict, objectives: dict, _=[], persis_info={}, gen_specs={}, libE_info=None, **kwargs): + super().__init__(variables, objectives, _, persis_info, gen_specs, libE_info, **kwargs) + self._get_user_params(self.gen_specs["user"]) + + def ask_numpy(self, n_trials): + return list_dicts_to_np( + UniformSampleDicts( + self.variables, self.objectives, self.History, self.persis_info, self.gen_specs, self.libE_info + ).ask(n_trials) + ) + + def tell_numpy(self, calc_in): + pass # random sample so nothing to tell + + +# List of dictionaries format for ask (constructor currently using numpy still) +# Mostly standard generator interface for libE generators will use the ask/tell wrappers +# to the classes above. This is for testing a function written directly with that interface. +class UniformSampleDicts(Generator): + """ + This generator returns ``gen_specs["initial_batch_size"]`` uniformly + sampled points the first time it is called. Afterwards, it returns the + number of points given. This can be used in either a batch or asynchronous + mode by adjusting the allocation function. + + This currently adheres to the complete standard. + """ + + def __init__(self, variables: dict, objectives: dict, _, persis_info, gen_specs, libE_info=None, **kwargs): + self.variables = variables + self.gen_specs = gen_specs + self.persis_info = persis_info + + def ask(self, n_trials): + H_o = [] + for _ in range(n_trials): + trial = {} + for key in self.variables.keys(): + trial[key] = self.persis_info["rand_stream"].uniform(self.variables[key][0], self.variables[key][1]) + H_o.append(trial) + return H_o + + def tell(self, calc_in): + pass # random sample so nothing to tell diff --git a/libensemble/gen_funcs/aposmm_localopt_support.py b/libensemble/gen_funcs/aposmm_localopt_support.py index 0bd1b9f3c..499bc38d5 100644 --- a/libensemble/gen_funcs/aposmm_localopt_support.py +++ b/libensemble/gen_funcs/aposmm_localopt_support.py @@ -683,7 +683,7 @@ def put_set_wait_get(x, comm_queue, parent_can_read, child_can_read, user_specs) if user_specs.get("periodic"): assert np.allclose(x % 1, values[0] % 1, rtol=1e-15, atol=1e-15), "The point I gave is not the point I got back" else: - assert np.allclose(x, values[0], rtol=1e-15, atol=1e-15), "The point I gave is not the point I got back" + assert np.allclose(x, values[0], rtol=1e-8, atol=1e-8), "The point I gave is not the point I got back" return values diff --git a/libensemble/generators.py b/libensemble/generators.py new file mode 100644 index 000000000..d8cb06cb8 --- /dev/null +++ b/libensemble/generators.py @@ -0,0 +1,251 @@ +# import queue as thread_queue +from abc import ABC, abstractmethod + +# from multiprocessing import Queue as process_queue +from typing import List, Optional + +import numpy as np +from numpy import typing as npt + +from libensemble.comms.comms import QCommProcess # , QCommThread +from libensemble.executors import Executor +from libensemble.message_numbers import EVAL_GEN_TAG, PERSIS_STOP +from libensemble.tools.tools import add_unique_random_streams +from libensemble.utils.misc import list_dicts_to_np, np_to_list_dicts + +""" +NOTE: These generators, implementations, methods, and subclasses are in BETA, and + may change in future releases. + + The Generator interface is expected to roughly correspond with CAMPA's standard: + https://github.com/campa-consortium/generator_standard +""" + + +class GeneratorNotStartedException(Exception): + """Exception raised by a threaded/multiprocessed generator upon being asked without having been started""" + + +class Generator(ABC): + """ + + .. code-block:: python + + from libensemble.specs import GenSpecs + from libensemble.generators import Generator + + + class MyGenerator(Generator): + def __init__(self, variables, objectives, param): + self.param = param + self.model = create_model(variables, objectives, self.param) + + def ask(self, num_points): + return create_points(num_points, self.param) + + def tell(self, results): + self.model = update_model(results, self.model) + + def final_tell(self, results): + self.tell(results) + return list(self.model) + + + variables = {"a": [-1, 1], "b": [-2, 2]} + objectives = {"f": "MINIMIZE"} + + my_generator = MyGenerator(variables, objectives, my_parameter=100) + gen_specs = GenSpecs(generator=my_generator, ...) + """ + + @abstractmethod + def __init__(self, variables: dict[str, List[float]], objectives: dict[str, str], *args, **kwargs): + """ + Initialize the Generator object on the user-side. Constants, class-attributes, + and preparation goes here. + + .. code-block:: python + + my_generator = MyGenerator(my_parameter, batch_size=10) + """ + + @abstractmethod + def ask(self, num_points: Optional[int]) -> List[dict]: + """ + Request the next set of points to evaluate. + """ + + def ask_updates(self) -> List[npt.NDArray]: + """ + Request any updates to previous points, e.g. minima discovered, points to cancel. + """ + + def tell(self, results: List[dict]) -> None: + """ + Send the results of evaluations to the generator. + """ + + def final_tell(self, results: List[dict], *args, **kwargs) -> Optional[npt.NDArray]: + """ + Send the last set of results to the generator, instruct it to cleanup, and + optionally retrieve an updated final state of evaluations. This is a separate + method to simplify the common pattern of noting internally if a + specific tell is the last. This will be called only once. + """ + + +class LibensembleGenerator(Generator): + """Internal implementation of Generator interface for use with libEnsemble, or for those who + prefer numpy arrays. ``ask/tell`` methods communicate lists of dictionaries, like the standard. + ``ask_numpy/tell_numpy`` methods communicate numpy arrays containing the same data. + """ + + def __init__( + self, + variables: dict, + objectives: dict = {}, + History: npt.NDArray = [], + persis_info: dict = {}, + gen_specs: dict = {}, + libE_info: dict = {}, + **kwargs, + ): + self.variables = variables + self.objectives = objectives + self.History = History + self.gen_specs = gen_specs + self.libE_info = libE_info + + self.variables_mapping = kwargs.get("variables_mapping", {}) + + self._internal_variable = "x" # need to figure these out dynamically + self._internal_objective = "f" + + if self.variables: + + self.n = len(self.variables) + # build our own lb and ub + if "lb" not in kwargs and "ub" not in kwargs: + lb = [] + ub = [] + for i, v in enumerate(self.variables.values()): + if isinstance(v, list) and (isinstance(v[0], int) or isinstance(v[0], float)): + lb.append(v[0]) + ub.append(v[1]) + kwargs["lb"] = np.array(lb) + kwargs["ub"] = np.array(ub) + + if len(kwargs) > 0: # so user can specify gen-specific parameters as kwargs to constructor + if not self.gen_specs.get("user"): + self.gen_specs["user"] = {} + self.gen_specs["user"].update(kwargs) + if not persis_info.get("rand_stream"): + self.persis_info = add_unique_random_streams({}, 4, seed=4321)[1] + else: + self.persis_info = persis_info + + @abstractmethod + def ask_numpy(self, num_points: Optional[int] = 0) -> npt.NDArray: + """Request the next set of points to evaluate, as a NumPy array.""" + + @abstractmethod + def tell_numpy(self, results: npt.NDArray) -> None: + """Send the results, as a NumPy array, of evaluations to the generator.""" + + @staticmethod + def convert_np_types(dict_list): + return [ + {key: (value.item() if isinstance(value, np.generic) else value) for key, value in item.items()} + for item in dict_list + ] + + def ask(self, num_points: Optional[int] = 0) -> List[dict]: + """Request the next set of points to evaluate.""" + return LibensembleGenerator.convert_np_types( + np_to_list_dicts(self.ask_numpy(num_points), mapping=self.variables_mapping) + ) + + def tell(self, results: List[dict]) -> None: + """Send the results of evaluations to the generator.""" + self.tell_numpy(list_dicts_to_np(results, mapping=self.variables_mapping)) + + +class LibensembleGenThreadInterfacer(LibensembleGenerator): + """Implement ask/tell for traditionally written libEnsemble persistent generator functions. + Still requires a handful of libEnsemble-specific data-structures on initialization. + """ + + def __init__( + self, + variables: dict, + objectives: dict = {}, + History: npt.NDArray = [], + persis_info: dict = {}, + gen_specs: dict = {}, + libE_info: dict = {}, + **kwargs, + ) -> None: + super().__init__(variables, objectives, History, persis_info, gen_specs, libE_info, **kwargs) + self.gen_f = gen_specs["gen_f"] + self.History = History + self.libE_info = libE_info + self.thread = None + + def setup(self) -> None: + """Must be called once before calling ask/tell. Initializes the background thread.""" + if self.thread is not None: + return + # SH this contains the thread lock - removing.... wrong comm to pass on anyway. + if hasattr(Executor.executor, "comm"): + del Executor.executor.comm + self.libE_info["executor"] = Executor.executor + + self.thread = QCommProcess( # TRY A PROCESS + self.gen_f, + None, + self.History, + self.persis_info, + self.gen_specs, + self.libE_info, + user_function=True, + ) + + # SH this is a bit hacky - maybe it can be done inside comms (in _qcomm_main)? + self.libE_info["comm"] = self.thread.comm + + def _set_sim_ended(self, results: npt.NDArray) -> npt.NDArray: + new_results = np.zeros(len(results), dtype=self.gen_specs["out"] + [("sim_ended", bool), ("f", float)]) + for field in results.dtype.names: + try: + new_results[field] = results[field] + except ValueError: # lets not slot in data that the gen doesnt need? + continue + new_results["sim_ended"] = True + return new_results + + def tell(self, results: List[dict], tag: int = EVAL_GEN_TAG) -> None: + """Send the results of evaluations to the generator.""" + self.tell_numpy(list_dicts_to_np(results, mapping=self.variables_mapping), tag) + + def ask_numpy(self, num_points: int = 0) -> npt.NDArray: + """Request the next set of points to evaluate, as a NumPy array.""" + if self.thread is None: + self.setup() + self.thread.run() + _, ask_full = self.thread.recv() + return ask_full["calc_out"] + + def tell_numpy(self, results: npt.NDArray, tag: int = EVAL_GEN_TAG) -> None: + """Send the results of evaluations to the generator, as a NumPy array.""" + if results is not None: + results = self._set_sim_ended(results) + Work = {"libE_info": {"H_rows": np.copy(results["sim_id"]), "persistent": True, "executor": None}} + self.thread.send(tag, Work) + self.thread.send(tag, np.copy(results)) # SH for threads check - might need deepcopy due to dtype=object + else: + self.thread.send(tag, None) + + def final_tell(self, results: npt.NDArray = None) -> (npt.NDArray, dict, int): + """Send any last results to the generator, and it to close down.""" + self.tell_numpy(results, PERSIS_STOP) # conversion happens in tell + return self.thread.result() diff --git a/libensemble/libE.py b/libensemble/libE.py index 2762890bc..ec97ba9ed 100644 --- a/libensemble/libE.py +++ b/libensemble/libE.py @@ -281,7 +281,7 @@ def manager( logger.info(f"libE version v{__version__}") if "out" in gen_specs and ("sim_id", int) in gen_specs["out"]: - if "libensemble.gen_funcs" not in gen_specs["gen_f"].__module__: + if hasattr(gen_specs["gen_f"], "__module__") and "libensemble.gen_funcs" not in gen_specs["gen_f"].__module__: logger.manager_warning(_USER_SIM_ID_WARNING) try: @@ -459,6 +459,7 @@ def start_proc_team(nworkers, sim_specs, gen_specs, libE_specs, log_comm=True): for wcomm in wcomms: wcomm.run() + return wcomms diff --git a/libensemble/sim_funcs/borehole_kills.py b/libensemble/sim_funcs/borehole_kills.py index 54a31256b..47a00af90 100644 --- a/libensemble/sim_funcs/borehole_kills.py +++ b/libensemble/sim_funcs/borehole_kills.py @@ -5,7 +5,7 @@ from libensemble.sim_funcs.surmise_test_function import borehole_true -def subproc_borehole(H, delay): +def subproc_borehole(H, delay, poll_manager): """This evaluates the Borehole function using a subprocess running compiled code. @@ -15,14 +15,14 @@ def subproc_borehole(H, delay): """ with open("input", "w") as f: - H["thetas"][0].tofile(f) - H["x"][0].tofile(f) + H["thetas"].tofile(f) + H["x"].tofile(f) exctr = Executor.executor args = "input" + " " + str(delay) task = exctr.submit(app_name="borehole", app_args=args, stdout="out.txt", stderr="err.txt") - calc_status = exctr.polling_loop(task, delay=0.01, poll_manager=True) + calc_status = exctr.polling_loop(task, delay=0.01, poll_manager=poll_manager) if calc_status in MAN_KILL_SIGNALS + [TASK_FAILED]: f = np.inf @@ -45,7 +45,7 @@ def borehole(H, persis_info, sim_specs, libE_info): if sim_id > sim_specs["user"]["init_sample_size"]: delay = 2 + np.random.normal(scale=0.5) - f, calc_status = subproc_borehole(H, delay) + f, calc_status = subproc_borehole(H, delay, sim_specs["user"].get("poll_manager", True)) if calc_status in MAN_KILL_SIGNALS and "sim_killed" in H_o.dtype.names: H_o["sim_killed"] = True # For calling script to print only. diff --git a/libensemble/specs.py b/libensemble/specs.py index e8779f930..a1a5a718b 100644 --- a/libensemble/specs.py +++ b/libensemble/specs.py @@ -78,6 +78,11 @@ class GenSpecs(BaseModel): simulator function, and makes decisions based on simulator function output. """ + generator: Optional[object] = None + """ + A pre-initialized generator object. + """ + inputs: Optional[List[str]] = Field(default=[], alias="in") """ List of **field names** out of the complete history to pass @@ -105,6 +110,24 @@ class GenSpecs(BaseModel): calling them locally. """ + initial_batch_size: Optional[int] = 0 + """ + Number of initial points to request that the generator create. If zero, falls back to ``batch_size``. + If both options are zero, defaults to the number of workers. + + Note: Certain generators included with libEnsemble decide + batch sizes via ``gen_specs["user"]`` or other methods. + """ + + batch_size: Optional[int] = 0 + """ + Number of points to generate in each batch. If zero, falls back to the number of + completed evaluations most recently told to the generator. + + Note: Certain generators included with libEnsemble decide + batch sizes via ``gen_specs["user"]`` or other methods. + """ + threaded: Optional[bool] = False """ Instruct Worker process to launch user function to a thread. diff --git a/libensemble/tests/functionality_tests/test_asktell_sampling.py b/libensemble/tests/functionality_tests/test_asktell_sampling.py new file mode 100644 index 000000000..506118d5c --- /dev/null +++ b/libensemble/tests/functionality_tests/test_asktell_sampling.py @@ -0,0 +1,70 @@ +""" +Runs libEnsemble with Latin hypercube sampling on a simple 1D problem + +Execute via one of the following commands (e.g. 3 workers): + mpiexec -np 4 python test_sampling_asktell_gen.py + python test_sampling_asktell_gen.py --nworkers 3 --comms local + python test_sampling_asktell_gen.py --nworkers 3 --comms tcp + +The number of concurrent evaluations of the objective function will be 4-1=3. +""" + +# Do not change these lines - they are parsed by run-tests.sh +# TESTSUITE_COMMS: mpi local +# TESTSUITE_NPROCS: 2 4 + +import numpy as np + +# Import libEnsemble items for this test +from libensemble.alloc_funcs.start_only_persistent import only_persistent_gens as alloc_f +from libensemble.gen_classes.sampling import UniformSample +from libensemble.libE import libE +from libensemble.tools import add_unique_random_streams, parse_args + + +def sim_f(In): + Out = np.zeros(1, dtype=[("f", float)]) + Out["f"] = np.linalg.norm(In) + return Out + + +if __name__ == "__main__": + nworkers, is_manager, libE_specs, _ = parse_args() + libE_specs["gen_on_manager"] = True + + sim_specs = { + "sim_f": sim_f, + "in": ["x"], + "out": [("f", float), ("grad", float, 2)], + } + + gen_specs = { + "persis_in": ["x", "f", "grad", "sim_id"], + "out": [("x", float, (2,))], + "initial_batch_size": 20, + "batch_size": 10, + "user": { + "initial_batch_size": 20, # for wrapper + "lb": np.array([-3, -2]), + "ub": np.array([3, 2]), + }, + } + + variables = {"x0": [-3, 3], "x1": [-2, 2]} + + objectives = {"f": "EXPLORE"} + + alloc_specs = {"alloc_f": alloc_f} + exit_criteria = {"gen_max": 201} + + persis_info = add_unique_random_streams({}, nworkers + 1, seed=1234) + + # Using asktell runner - pass object + generator = UniformSample(variables, objectives) + gen_specs["generator"] = generator + + H, persis_info, flag = libE(sim_specs, gen_specs, exit_criteria, persis_info, alloc_specs, libE_specs=libE_specs) + + if is_manager: + print(H[["sim_id", "x", "f"]][:10]) + assert len(H) >= 201, f"H has length {len(H)}" diff --git a/libensemble/tests/regression_tests/test_asktell_aposmm_nlopt.py b/libensemble/tests/regression_tests/test_asktell_aposmm_nlopt.py new file mode 100644 index 000000000..25fbc6afb --- /dev/null +++ b/libensemble/tests/regression_tests/test_asktell_aposmm_nlopt.py @@ -0,0 +1,90 @@ +""" +Runs libEnsemble with APOSMM with the NLopt local optimizer. + +Execute via one of the following commands (e.g. 3 workers): + mpiexec -np 4 python test_persistent_aposmm_nlopt.py + python test_persistent_aposmm_nlopt.py --nworkers 3 --comms local + python test_persistent_aposmm_nlopt.py --nworkers 3 --comms tcp + +When running with the above commands, the number of concurrent evaluations of +the objective function will be 2, as one of the three workers will be the +persistent generator. +""" + +# Do not change these lines - they are parsed by run-tests.sh +# TESTSUITE_COMMS: local mpi tcp +# TESTSUITE_NPROCS: 3 + +import sys +from math import gamma, pi, sqrt + +import numpy as np + +import libensemble.gen_funcs + +# Import libEnsemble items for this test +from libensemble.sim_funcs.six_hump_camel import six_hump_camel as sim_f + +libensemble.gen_funcs.rc.aposmm_optimizers = "nlopt" +from time import time + +from libensemble import Ensemble +from libensemble.alloc_funcs.persistent_aposmm_alloc import persistent_aposmm_alloc as alloc_f +from libensemble.gen_classes import APOSMM +from libensemble.specs import AllocSpecs, ExitCriteria, GenSpecs, SimSpecs +from libensemble.tests.regression_tests.support import six_hump_camel_minima as minima + +# Main block is necessary only when using local comms with spawn start method (default on macOS and Windows). +if __name__ == "__main__": + + workflow = Ensemble(parse_args=True) + + if workflow.is_manager: + start_time = time() + + if workflow.nworkers < 2: + sys.exit("Cannot run with a persistent worker if only one worker -- aborting...") + + n = 2 + workflow.sim_specs = SimSpecs(sim_f=sim_f, inputs=["x"], outputs=[("f", float)]) + workflow.alloc_specs = AllocSpecs(alloc_f=alloc_f) + workflow.exit_criteria = ExitCriteria(sim_max=2000) + + aposmm = APOSMM( + variables={"x0": [-3, 3], "x1": [-2, 2]}, # we hope to combine these + objectives={"f": "MINIMIZE"}, + initial_sample_size=100, + sample_points=minima, + localopt_method="LN_BOBYQA", + rk_const=0.5 * ((gamma(1 + (n / 2)) * 5) ** (1 / n)) / sqrt(pi), + xtol_abs=1e-6, + ftol_abs=1e-6, + max_active_runs=workflow.nworkers, # should this match nworkers always? practically? + variables_mapping={"x": ["x0", "x1"]}, + ) + + workflow.gen_specs = GenSpecs( + persis_in=["x", "x_on_cube", "sim_id", "local_min", "local_pt", "f"], + generator=aposmm, + batch_size=5, + initial_batch_size=10, + user={"initial_sample_size": 100}, + ) + + workflow.libE_specs.gen_on_manager = True + workflow.add_random_streams() + + H, _, _ = workflow.run() + + # Perform the run + + if workflow.is_manager: + print("[Manager]:", H[np.where(H["local_min"])]["x"]) + print("[Manager]: Time taken =", time() - start_time, flush=True) + + tol = 1e-5 + for m in minima: + # The minima are known on this test problem. + # We use their values to test APOSMM has identified all minima + print(np.min(np.sum((H[H["local_min"]]["x"] - m) ** 2, 1)), flush=True) + assert np.min(np.sum((H[H["local_min"]]["x"] - m) ** 2, 1)) < tol diff --git a/libensemble/tests/regression_tests/test_asktell_gpCAM.py b/libensemble/tests/regression_tests/test_asktell_gpCAM.py new file mode 100644 index 000000000..1c8e2559c --- /dev/null +++ b/libensemble/tests/regression_tests/test_asktell_gpCAM.py @@ -0,0 +1,98 @@ +""" +Tests libEnsemble with gpCAM + +Execute via one of the following commands (e.g. 3 workers): + mpiexec -np 4 python test_gpCAM_class.py + python test_gpCAM_class.py --nworkers 3 --comms local + +When running with the above commands, the number of concurrent evaluations of +the objective function will be 2, as one of the three workers will be the +persistent generator. + +See libensemble.gen_funcs.persistent_gpCAM for more details about the generator +setup. +""" + +# Do not change these lines - they are parsed by run-tests.sh +# TESTSUITE_COMMS: mpi local +# TESTSUITE_NPROCS: 4 +# TESTSUITE_EXTRA: true +# TESTSUITE_EXCLUDE: true + +import sys +import warnings + +import numpy as np + +from libensemble.alloc_funcs.start_only_persistent import only_persistent_gens as alloc_f +from libensemble.gen_classes.gpCAM import GP_CAM, GP_CAM_Covar + +# Import libEnsemble items for this test +from libensemble.libE import libE +from libensemble.sim_funcs.rosenbrock import rosenbrock_eval as sim_f +from libensemble.tools import add_unique_random_streams, parse_args, save_libE_output + +warnings.filterwarnings("ignore", message="Default hyperparameter_bounds") + + +# Main block is necessary only when using local comms with spawn start method (default on macOS and Windows). +if __name__ == "__main__": + nworkers, is_manager, libE_specs, _ = parse_args() + + if nworkers < 2: + sys.exit("Cannot run with a persistent worker if only one worker -- aborting...") + + n = 4 + batch_size = 15 + + sim_specs = { + "sim_f": sim_f, + "in": ["x"], + "out": [ + ("f", float), + ], + } + + gen_specs = { + "persis_in": ["x", "f", "sim_id"], + "out": [("x", float, (n,))], + "user": { + "batch_size": batch_size, + "lb": np.array([-3, -2, -1, -1]), + "ub": np.array([3, 2, 1, 1]), + }, + } + + alloc_specs = {"alloc_f": alloc_f} + + persis_info = add_unique_random_streams({}, nworkers + 1) + + gen = GP_CAM_Covar(None, persis_info[1], gen_specs, None) + + for inst in range(3): + if inst == 0: + gen_specs["generator"] = gen + num_batches = 10 + exit_criteria = {"sim_max": num_batches * batch_size, "wallclock_max": 300} + libE_specs["save_every_k_gens"] = 150 + libE_specs["H_file_prefix"] = "gpCAM_nongrid" + if inst == 1: + gen_specs["user"]["use_grid"] = True + gen_specs["user"]["test_points_file"] = "gpCAM_nongrid_after_gen_150.npy" + libE_specs["final_gen_send"] = True + del libE_specs["H_file_prefix"] + del libE_specs["save_every_k_gens"] + elif inst == 2: + persis_info = add_unique_random_streams({}, nworkers + 1) + gen_specs["generator"] = GP_CAM(None, persis_info[1], gen_specs, None) + num_batches = 3 # Few because the ask_tell gen can be slow + gen_specs["user"]["ask_max_iter"] = 1 # For quicker test + exit_criteria = {"sim_max": num_batches * batch_size, "wallclock_max": 300} + + # Perform the run + H, persis_info, flag = libE(sim_specs, gen_specs, exit_criteria, persis_info, alloc_specs, libE_specs) + + if is_manager: + assert len(np.unique(H["gen_ended_time"])) == num_batches + + save_libE_output(H, persis_info, __file__, nworkers) diff --git a/libensemble/tests/regression_tests/test_gpCAM.py b/libensemble/tests/regression_tests/test_gpCAM.py index e1bc1e404..9e87211e8 100644 --- a/libensemble/tests/regression_tests/test_gpCAM.py +++ b/libensemble/tests/regression_tests/test_gpCAM.py @@ -34,11 +34,10 @@ from libensemble.sim_funcs.rosenbrock import rosenbrock_eval as sim_f from libensemble.tools import add_unique_random_streams, parse_args, save_libE_output -# Main block is necessary only when using local comms with spawn start method (default on macOS and Windows). - warnings.filterwarnings("ignore", message="Default hyperparameter_bounds") +# Main block is necessary only when using local comms with spawn start method (default on macOS and Windows). if __name__ == "__main__": nworkers, is_manager, libE_specs, _ = parse_args() diff --git a/libensemble/tests/regression_tests/test_persistent_aposmm_nlopt.py b/libensemble/tests/regression_tests/test_persistent_aposmm_nlopt.py index 681133016..2bcd7bf6b 100644 --- a/libensemble/tests/regression_tests/test_persistent_aposmm_nlopt.py +++ b/libensemble/tests/regression_tests/test_persistent_aposmm_nlopt.py @@ -79,7 +79,7 @@ alloc_specs = {"alloc_f": alloc_f} - persis_info = add_unique_random_streams({}, nworkers + 1) + persis_info = add_unique_random_streams({}, nworkers + 1, seed=4321) exit_criteria = {"sim_max": 2000} diff --git a/libensemble/tests/unit_tests/test_asktell.py b/libensemble/tests/unit_tests/test_asktell.py new file mode 100644 index 000000000..1364b7031 --- /dev/null +++ b/libensemble/tests/unit_tests/test_asktell.py @@ -0,0 +1,138 @@ +import numpy as np + +from libensemble.utils.misc import list_dicts_to_np + + +def _check_conversion(H, npp, mapping={}): + + for field in H.dtype.names: + print(f"Comparing {field}: {H[field]} {npp[field]}") + + if isinstance(H[field], np.ndarray): + assert np.array_equal(H[field], npp[field]), f"Mismatch found in field {field}" + + elif isinstance(H[field], str) and isinstance(npp[field], str): + assert H[field] == npp[field], f"Mismatch found in field {field}" + + elif np.isscalar(H[field]) and np.isscalar(npp[field]): + assert np.isclose(H[field], npp[field]), f"Mismatch found in field {field}" + + else: + raise TypeError(f"Unhandled or mismatched types in field {field}: {type(H[field])} vs {type(npp[field])}") + + +def test_asktell_sampling_and_utils(): + from libensemble.gen_classes.sampling import UniformSample + + variables = {"x0": [-3, 3], "x1": [-2, 2]} + objectives = {"f": "EXPLORE"} + + # Test initialization with libensembley parameters + gen = UniformSample(variables, objectives) + assert len(gen.ask(10)) == 10 + + out_np = gen.ask_numpy(3) # should get numpy arrays, non-flattened + out = gen.ask(3) # needs to get dicts, 2d+ arrays need to be flattened + + assert all([len(x) == 2 for x in out]) # np_to_list_dicts is now tested + + # now we test list_dicts_to_np directly + out_np = list_dicts_to_np(out) + + # check combined values resemble flattened list-of-dicts values + assert out_np.dtype.names == ("x",) + for i, entry in enumerate(out): + for j, value in enumerate(entry.values()): + assert value == out_np["x"][i][j] + + variables = {"core": [-3, 3], "edge": [-2, 2]} + objectives = {"energy": "EXPLORE"} + mapping = {"x": ["core", "edge"]} + + gen = UniformSample(variables, objectives, mapping) + out = gen.ask(1) + assert len(out) == 1 + assert out[0].get("core") + assert out[0].get("edge") + + out_np = list_dicts_to_np(out, mapping=mapping) + assert out_np.dtype.names[0] == "x" + + +def test_awkward_list_dict(): + from libensemble.utils.misc import list_dicts_to_np + + # test list_dicts_to_np on a weirdly formatted dictionary + # Unfortunately, we're not really checking against some original + # libE-styled source of truth, like H. + + weird_list_dict = [ + { + "x0": "abcd", + "x1": "efgh", + "y": 56, + "z0": 1, + "z1": 2, + "z2": 3, + "z3": 4, + "z4": 5, + "z5": 6, + "z6": 7, + "z7": 8, + "z8": 9, + "z9": 10, + "z10": 11, + "a0": "B", + } + ] + + out_np = list_dicts_to_np(weird_list_dict) + + assert all([i in ("x", "y", "z", "a0") for i in out_np.dtype.names]) + + weird_list_dict = [ + { + "sim_id": 77, + "core": 89, + "edge": 10.1, + "beam": 76.5, + "energy": 12.34, + "local_pt": True, + "local_min": False, + }, + { + "sim_id": 10, + "core": 32.8, + "edge": 16.2, + "beam": 33.5, + "energy": 99.34, + "local_pt": False, + "local_min": False, + }, + ] + + # target dtype: [("sim_id", int), ("x, float, (3,)), ("f", float), ("local_pt", bool), ("local_min", bool)] + + mapping = {"x": ["core", "edge", "beam"], "f": ["energy"]} + out_np = list_dicts_to_np(weird_list_dict, mapping=mapping) + + assert all([i in ("sim_id", "x", "f", "local_pt", "local_min") for i in out_np.dtype.names]) + + +def test_awkward_H(): + from libensemble.utils.misc import list_dicts_to_np, np_to_list_dicts + + dtype = [("a", "i4"), ("x", "f4", (3,)), ("y", "f4", (1,)), ("z", "f4", (12,)), ("greeting", "U10"), ("co2", "f8")] + H = np.zeros(2, dtype=dtype) + H[0] = (1, [1.1, 2.2, 3.3], [10.1], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "hello", "1.23") + H[1] = (2, [4.4, 5.5, 6.6], [11.1], [51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], "goodbye", "2.23") + + list_dicts = np_to_list_dicts(H) + npp = list_dicts_to_np(list_dicts, dtype=dtype) + _check_conversion(H, npp) + + +if __name__ == "__main__": + test_asktell_sampling_and_utils() + test_awkward_list_dict() + test_awkward_H() diff --git a/libensemble/tests/unit_tests/RENAME_test_persistent_aposmm.py b/libensemble/tests/unit_tests/test_persistent_aposmm.py similarity index 69% rename from libensemble/tests/unit_tests/RENAME_test_persistent_aposmm.py rename to libensemble/tests/unit_tests/test_persistent_aposmm.py index b08bc85fa..25ecdfd46 100644 --- a/libensemble/tests/unit_tests/RENAME_test_persistent_aposmm.py +++ b/libensemble/tests/unit_tests/test_persistent_aposmm.py @@ -168,8 +168,86 @@ def test_standalone_persistent_aposmm_combined_func(): assert persis_info.get("run_order"), "Standalone persistent_aposmm didn't do any localopt runs" +@pytest.mark.extra +def test_asktell_with_persistent_aposmm(): + from math import gamma, pi, sqrt + + import libensemble.gen_funcs + from libensemble.gen_classes import APOSMM + from libensemble.message_numbers import FINISHED_PERSISTENT_GEN_TAG + from libensemble.sim_funcs.six_hump_camel import six_hump_camel_func + from libensemble.tests.regression_tests.support import six_hump_camel_minima as minima + + libensemble.gen_funcs.rc.aposmm_optimizers = "nlopt" + + n = 2 + eval_max = 2000 + + gen_specs = { + "user": { + "initial_sample_size": 100, + "sample_points": np.round(minima, 1), + "localopt_method": "LN_BOBYQA", + "rk_const": 0.5 * ((gamma(1 + (n / 2)) * 5) ** (1 / n)) / sqrt(pi), + "xtol_abs": 1e-6, + "ftol_abs": 1e-6, + "dist_to_bound_multiple": 0.5, + "max_active_runs": 6, + }, + } + + variables = {"core": [-3, 3], "edge": [-2, 2]} + objectives = {"energy": "MINIMIZE"} + variables_mapping = {"x": ["core", "edge"], "f": ["energy"]} + + my_APOSMM = APOSMM( + variables=variables, objectives=objectives, gen_specs=gen_specs, variables_mapping=variables_mapping + ) + + initial_sample = my_APOSMM.ask(100) + + total_evals = 0 + eval_max = 2000 + + for point in initial_sample: + point["energy"] = six_hump_camel_func(np.array([point["core"], point["edge"]])) + total_evals += 1 + + my_APOSMM.tell(initial_sample) + + potential_minima = [] + + while total_evals < eval_max: + + sample, detected_minima = my_APOSMM.ask(6), my_APOSMM.ask_updates() + if len(detected_minima): + for m in detected_minima: + potential_minima.append(m) + for point in sample: + point["energy"] = six_hump_camel_func(np.array([point["core"], point["edge"]])) + total_evals += 1 + my_APOSMM.tell(sample) + H, persis_info, exit_code = my_APOSMM.final_tell() + + assert exit_code == FINISHED_PERSISTENT_GEN_TAG, "Standalone persistent_aposmm didn't exit correctly" + assert persis_info.get("run_order"), "Standalone persistent_aposmm didn't do any localopt runs" + + assert len(potential_minima) >= 6, f"Found {len(potential_minima)} minima" + + tol = 1e-3 + min_found = 0 + for m in minima: + # The minima are known on this test problem. + # We use their values to test APOSMM has identified all minima + print(np.min(np.sum((H[H["local_min"]]["x"] - m) ** 2, 1)), flush=True) + if np.min(np.sum((H[H["local_min"]]["x"] - m) ** 2, 1)) < tol: + min_found += 1 + assert min_found >= 6, f"Found {min_found} minima" + + if __name__ == "__main__": test_persis_aposmm_localopt_test() test_update_history_optimal() test_standalone_persistent_aposmm() test_standalone_persistent_aposmm_combined_func() + test_asktell_with_persistent_aposmm() diff --git a/libensemble/tests/unit_tests/test_ufunc_runners.py b/libensemble/tests/unit_tests/test_ufunc_runners.py index 1d3cbb4b2..51aa8c65d 100644 --- a/libensemble/tests/unit_tests/test_ufunc_runners.py +++ b/libensemble/tests/unit_tests/test_ufunc_runners.py @@ -30,8 +30,8 @@ def get_ufunc_args(): def test_normal_runners(): calc_in, sim_specs, gen_specs = get_ufunc_args() - simrunner = Runner(sim_specs) - genrunner = Runner(gen_specs) + simrunner = Runner.from_specs(sim_specs) + genrunner = Runner.from_specs(gen_specs) assert not hasattr(simrunner, "globus_compute_executor") and not hasattr( genrunner, "globus_compute_executor" ), "Globus Compute use should not be detected without setting endpoint fields" @@ -47,7 +47,7 @@ def tupilize(arg1, arg2): sim_specs["sim_f"] = tupilize persis_info = {"hello": "threads"} - simrunner = Runner(sim_specs) + simrunner = Runner.from_specs(sim_specs) result = simrunner._result(calc_in, persis_info, {}) assert result == (calc_in, persis_info) assert hasattr(simrunner, "thread_handle") @@ -61,7 +61,7 @@ def test_globus_compute_runner_init(): sim_specs["globus_compute_endpoint"] = "1234" with mock.patch("globus_compute_sdk.Executor"): - runner = Runner(sim_specs) + runner = Runner.from_specs(sim_specs) assert hasattr( runner, "globus_compute_executor" @@ -75,7 +75,7 @@ def test_globus_compute_runner_pass(): sim_specs["globus_compute_endpoint"] = "1234" with mock.patch("globus_compute_sdk.Executor"): - runner = Runner(sim_specs) + runner = Runner.from_specs(sim_specs) # Creating Mock Globus ComputeExecutor and Globus Compute future object - no exception globus_compute_mock = mock.Mock() @@ -101,7 +101,7 @@ def test_globus_compute_runner_fail(): gen_specs["globus_compute_endpoint"] = "4321" with mock.patch("globus_compute_sdk.Executor"): - runner = Runner(gen_specs) + runner = Runner.from_specs(gen_specs) # Creating Mock Globus ComputeExecutor and Globus Compute future object - yes exception globus_compute_mock = mock.Mock() diff --git a/libensemble/tools/alloc_support.py b/libensemble/tools/alloc_support.py index 0d4ce91d8..9e2fb5d8c 100644 --- a/libensemble/tools/alloc_support.py +++ b/libensemble/tools/alloc_support.py @@ -280,6 +280,7 @@ def gen_work(self, wid, H_fields, H_rows, persis_info, **libE_info): H_fields = AllocSupport._check_H_fields(H_fields) libE_info["H_rows"] = AllocSupport._check_H_rows(H_rows) + libE_info["batch_size"] = len(self.avail_worker_ids(gen_workers=False)) work = { "H_fields": H_fields, diff --git a/libensemble/utils/misc.py b/libensemble/utils/misc.py index ca67095ac..7cc9c1a2a 100644 --- a/libensemble/utils/misc.py +++ b/libensemble/utils/misc.py @@ -2,10 +2,13 @@ Misc internal functions """ -from itertools import groupby +from itertools import chain, groupby from operator import itemgetter +from typing import List +import numpy as np import pydantic +from numpy import typing as npt pydantic_version = pydantic.__version__[0] @@ -76,3 +79,121 @@ def specs_checker_setattr(obj, key, value): obj[key] = value else: # actual obj obj.__dict__[key] = value + + +def _decide_dtype(name: str, entry, size: int) -> tuple: + if isinstance(entry, str): + output_type = "U" + str(len(entry) + 1) + else: + output_type = type(entry) + if size == 1 or not size: + return (name, output_type) + else: + return (name, output_type, (size,)) + + +def _combine_names(names: list) -> list: + """combine fields with same name *except* for final digits""" + + out_names = [] + stripped = list(i.rstrip("0123456789") for i in names) # ['x', 'x', y', 'z', 'a'] + for name in names: + stripped_name = name.rstrip("0123456789") + if stripped.count(stripped_name) > 1: # if name appears >= 1, will combine, don't keep int suffix + out_names.append(stripped_name) + else: + out_names.append(name) # name appears once, keep integer suffix, e.g. "co2" + + # intending [x, y, z, a0] from [x0, x1, y, z0, z1, z2, z3, a0] + return list(set(out_names)) + + +def list_dicts_to_np(list_dicts: list, dtype: list = None, mapping: dict = {}) -> npt.NDArray: + if list_dicts is None: + return None + + if not isinstance(list_dicts, list): # presumably already a numpy array, conversion not necessary + return list_dicts + + for entry in list_dicts: + if "_id" in entry: + entry["sim_id"] = entry.pop("_id") + + if dtype is None: + dtype = [] + + # build a presumptive dtype + + first = list_dicts[0] # for determining dtype of output np array + new_dtype_names = _combine_names([i for i in first.keys()]) # -> ['x', 'y'] + fields_to_convert = list(chain.from_iterable(list(mapping.values()))) + new_dtype_names = [i for i in new_dtype_names if i not in fields_to_convert] + list(mapping.keys()) + combinable_names = [] # [['x0', 'x1'], ['y0', 'y1', 'y2'], ['z']] + for name in new_dtype_names: + combinable_group = [i for i in first.keys() if i.rstrip("0123456789") == name] + if len(combinable_group) > 1: # multiple similar names, e.g. x0, x1 + combinable_names.append(combinable_group) + else: # single name, e.g. local_pt, a0 *AS LONG AS THERE ISNT AN A1* + combinable_names.append([name]) + + # build dtype of non-mapped fields + if not len(dtype): + for i, entry in enumerate(combinable_names): + name = new_dtype_names[i] + size = len(combinable_names[i]) + if name not in mapping: + dtype.append(_decide_dtype(name, first[entry[0]], size)) + + # append dtype of mapped float fields + if len(mapping): + for name in mapping: + size = len(mapping[name]) + dtype.append(_decide_dtype(name, 0.0, size)) # float + + out = np.zeros(len(list_dicts), dtype=dtype) + + for j, input_dict in enumerate(list_dicts): + for output_name, field_names in zip(new_dtype_names, combinable_names): + if output_name not in mapping: + out[output_name][j] = ( + tuple(input_dict[name] for name in field_names) + if len(field_names) > 1 + else input_dict[field_names[0]] + ) + else: + out[output_name][j] = ( + tuple(input_dict[name] for name in mapping[output_name]) + if len(mapping[output_name]) > 1 + else input_dict[mapping[output_name][0]] + ) + + return out + + +def np_to_list_dicts(array: npt.NDArray, mapping: dict = {}) -> List[dict]: + if array is None: + return None + out = [] + for row in array: + new_dict = {} + for field in row.dtype.names: + # non-string arrays, lists, etc. + if field not in list(mapping.keys()): + if hasattr(row[field], "__len__") and len(row[field]) > 1 and not isinstance(row[field], str): + for i, x in enumerate(row[field]): + new_dict[field + str(i)] = x + elif hasattr(row[field], "__len__") and len(row[field]) == 1: # single-entry arrays, lists, etc. + new_dict[field] = row[field][0] # will still work on single-char strings + else: + new_dict[field] = row[field] + else: + assert array.dtype[field].shape[0] == len(mapping[field]), "unable to unpack multidimensional array" + for i, name in enumerate(mapping[field]): + new_dict[name] = row[field][i] + out.append(new_dict) + + for entry in out: + if "sim_id" in entry: + entry["_id"] = entry.pop("sim_id") + + return out diff --git a/libensemble/utils/pydantic_bindings.py b/libensemble/utils/pydantic_bindings.py index 7ceca9615..5c1f6e17d 100644 --- a/libensemble/utils/pydantic_bindings.py +++ b/libensemble/utils/pydantic_bindings.py @@ -5,7 +5,7 @@ from libensemble import specs from libensemble.resources import platforms from libensemble.utils.misc import pydanticV1 -from libensemble.utils.validators import ( +from libensemble.utils.validators import ( # check_output_fields, _UFUNC_INVALID_ERR, _UNRECOGNIZED_ERR, check_any_workers_and_disable_rm_if_tcp, @@ -16,8 +16,8 @@ check_inputs_exist, check_logical_cores, check_mpi_runner_type, - check_output_fields, check_provided_ufuncs, + check_set_gen_specs_from_variables, check_valid_comms_type, check_valid_in, check_valid_out, @@ -104,6 +104,7 @@ class Config: __validators__={ "check_valid_out": check_valid_out, "check_valid_in": check_valid_in, + "check_set_gen_specs_from_variables": check_set_gen_specs_from_variables, "genf_set_in_out_from_attrs": genf_set_in_out_from_attrs, }, ) @@ -129,7 +130,6 @@ class Config: __base__=specs._EnsembleSpecs, __validators__={ "check_exit_criteria": check_exit_criteria, - "check_output_fields": check_output_fields, "check_H0": check_H0, "check_provided_ufuncs": check_provided_ufuncs, }, diff --git a/libensemble/utils/runners.py b/libensemble/utils/runners.py index 629c733b1..eea0cfcf7 100644 --- a/libensemble/utils/runners.py +++ b/libensemble/utils/runners.py @@ -1,23 +1,36 @@ import inspect import logging import logging.handlers +import time from typing import Optional import numpy.typing as npt from libensemble.comms.comms import QCommThread +from libensemble.generators import LibensembleGenerator, LibensembleGenThreadInterfacer +from libensemble.message_numbers import EVAL_GEN_TAG, FINISHED_PERSISTENT_GEN_TAG, PERSIS_STOP, STOP_TAG +from libensemble.tools.persistent_support import PersistentSupport +from libensemble.utils.misc import list_dicts_to_np, np_to_list_dicts logger = logging.getLogger(__name__) class Runner: - def __new__(cls, specs): + @classmethod + def from_specs(cls, specs): if len(specs.get("globus_compute_endpoint", "")) > 0: - return super(Runner, GlobusComputeRunner).__new__(GlobusComputeRunner) - if specs.get("threaded"): # TODO: undecided interface - return super(Runner, ThreadRunner).__new__(ThreadRunner) + return GlobusComputeRunner(specs) + if specs.get("threaded"): + return ThreadRunner(specs) + if (generator := specs.get("generator")) is not None: + if isinstance(generator, LibensembleGenThreadInterfacer): + return LibensembleGenThreadRunner(specs) + if isinstance(generator, LibensembleGenerator): + return LibensembleGenRunner(specs) + else: + return AskTellGenRunner(specs) else: - return super().__new__(Runner) + return Runner(specs) def __init__(self, specs): self.specs = specs @@ -84,3 +97,103 @@ def _result(self, calc_in: npt.NDArray, persis_info: dict, libE_info: dict) -> ( def shutdown(self) -> None: if self.thread_handle is not None: self.thread_handle.terminate() + + +class AskTellGenRunner(Runner): + """Interact with ask/tell generator. Base class initialized for third-party generators.""" + + def __init__(self, specs): + super().__init__(specs) + self.gen = specs.get("generator") + + def _get_points_updates(self, batch_size: int) -> (npt.NDArray, npt.NDArray): + # no ask_updates on external gens + return ( + list_dicts_to_np(self.gen.ask(batch_size), dtype=self.specs.get("out"), mapping=self.gen.variables_mapping), + None, + ) + + def _convert_tell(self, x: npt.NDArray) -> list: + self.gen.tell(np_to_list_dicts(x)) + + def _loop_over_gen(self, tag, Work, H_in): + """Interact with ask/tell generator that *does not* contain a background thread""" + while tag not in [PERSIS_STOP, STOP_TAG]: + batch_size = self.specs.get("batch_size") or len(H_in) + H_out, _ = self._get_points_updates(batch_size) + tag, Work, H_in = self.ps.send_recv(H_out) + self._convert_tell(H_in) + return H_in + + def _get_initial_ask(self, libE_info) -> npt.NDArray: + """Get initial batch from generator based on generator type""" + initial_batch = self.specs.get("initial_batch_size") or self.specs.get("batch_size") or libE_info["batch_size"] + H_out = self.gen.ask(initial_batch) + return H_out + + def _start_generator_loop(self, tag, Work, H_in): + """Start the generator loop after choosing best way of giving initial results to gen""" + self.gen.tell(np_to_list_dicts(H_in)) + return self._loop_over_gen(tag, Work, H_in) + + def _persistent_result(self, calc_in, persis_info, libE_info): + """Setup comms with manager, setup gen, loop gen to completion, return gen's results""" + self.ps = PersistentSupport(libE_info, EVAL_GEN_TAG) + # libE gens will hit the following line, but list_dicts_to_np will passthrough if the output is a numpy array + H_out = list_dicts_to_np( + self._get_initial_ask(libE_info), dtype=self.specs.get("out"), mapping=self.gen.variables_mapping + ) + tag, Work, H_in = self.ps.send_recv(H_out) # evaluate the initial sample + final_H_in = self._start_generator_loop(tag, Work, H_in) + return self.gen.final_tell(final_H_in), FINISHED_PERSISTENT_GEN_TAG + + def _result(self, calc_in: npt.NDArray, persis_info: dict, libE_info: dict) -> (npt.NDArray, dict, Optional[int]): + if libE_info.get("persistent"): + return self._persistent_result(calc_in, persis_info, libE_info) + raise ValueError("ask/tell generators must run in persistent mode. This may be the default in the future.") + + +class LibensembleGenRunner(AskTellGenRunner): + def _get_initial_ask(self, libE_info) -> npt.NDArray: + """Get initial batch from generator based on generator type""" + H_out = self.gen.ask_numpy(libE_info["batch_size"]) # OR GEN SPECS INITIAL BATCH SIZE + return H_out + + def _get_points_updates(self, batch_size: int) -> (npt.NDArray, list): + return self.gen.ask_numpy(batch_size), self.gen.ask_updates() + + def _convert_tell(self, x: npt.NDArray) -> list: + self.gen.tell_numpy(x) + + def _start_generator_loop(self, tag, Work, H_in) -> npt.NDArray: + """Start the generator loop after choosing best way of giving initial results to gen""" + self.gen.tell_numpy(H_in) + return self._loop_over_gen(tag, Work, H_in) # see parent class + + +class LibensembleGenThreadRunner(AskTellGenRunner): + def _get_initial_ask(self, libE_info) -> npt.NDArray: + """Get initial batch from generator based on generator type""" + return self.gen.ask_numpy() # libE really needs to receive the *entire* initial batch from a threaded gen + + def _ask_and_send(self): + """Loop over generator's outbox contents, send to manager""" + while not self.gen.thread.outbox.empty(): # recv/send any outstanding messages + points, updates = self.gen.ask_numpy(), self.gen.ask_updates() + if updates is not None and len(updates): + self.ps.send(points) + for i in updates: + self.ps.send(i, keep_state=True) # keep_state since an update doesn't imply "new points" + else: + self.ps.send(points) + + def _loop_over_gen(self, *args): + """Cycle between moving all outbound / inbound messages between threaded gen and manager""" + while True: + time.sleep(0.0025) # dont need to ping the gen relentlessly. Let it calculate. 400hz + self._ask_and_send() + while self.ps.comm.mail_flag(): # receive any new messages from Manager, give all to gen + tag, _, H_in = self.ps.recv() + if tag in [STOP_TAG, PERSIS_STOP]: + return H_in # this will get inserted into final_tell. this breaks loop + self.gen.tell_numpy(H_in) diff --git a/libensemble/utils/specs_checkers.py b/libensemble/utils/specs_checkers.py index cf33d359f..b8e793fa5 100644 --- a/libensemble/utils/specs_checkers.py +++ b/libensemble/utils/specs_checkers.py @@ -25,28 +25,10 @@ def _check_exit_criteria(values): return values -def _check_output_fields(values): - out_names = [e[0] for e in libE_fields] - if scg(values, "H0") is not None and scg(values, "H0").dtype.names is not None: - out_names += list(scg(values, "H0").dtype.names) - out_names += [e[0] for e in scg(values, "sim_specs").outputs] - if scg(values, "gen_specs"): - out_names += [e[0] for e in scg(values, "gen_specs").outputs] - if scg(values, "alloc_specs"): - out_names += [e[0] for e in scg(values, "alloc_specs").outputs] - - for name in scg(values, "sim_specs").inputs: - assert name in out_names, ( - name + " in sim_specs['in'] is not in sim_specs['out'], " - "gen_specs['out'], alloc_specs['out'], H0, or libE_fields." - ) - - if scg(values, "gen_specs"): - for name in scg(values, "gen_specs").inputs: - assert name in out_names, ( - name + " in gen_specs['in'] is not in sim_specs['out'], " - "gen_specs['out'], alloc_specs['out'], H0, or libE_fields." - ) +def _check_set_gen_specs_from_variables(values): + if not len(scg(values, "outputs")): + if scg(values, "generator") and len(scg(values, "generator").gen_specs["out"]): + scs(values, "outputs", scg(values, "generator").gen_specs["out"]) return values diff --git a/libensemble/utils/validators.py b/libensemble/utils/validators.py index e91d06a17..6cd100f4d 100644 --- a/libensemble/utils/validators.py +++ b/libensemble/utils/validators.py @@ -6,13 +6,13 @@ from libensemble.resources.platforms import Platform from libensemble.utils.misc import pydanticV1 -from libensemble.utils.specs_checkers import ( +from libensemble.utils.specs_checkers import ( # _check_output_fields, _check_any_workers_and_disable_rm_if_tcp, _check_exit_criteria, _check_H0, _check_logical_cores, - _check_output_fields, _check_set_calc_dirs_on_input_dir, + _check_set_gen_specs_from_variables, _check_set_workflow_dir, ) @@ -148,8 +148,8 @@ def check_exit_criteria(cls, values): return _check_exit_criteria(values) @root_validator - def check_output_fields(cls, values): - return _check_output_fields(values) + def check_set_gen_specs_from_variables(cls, values): + return _check_set_gen_specs_from_variables(values) @root_validator def check_H0(cls, values): @@ -158,13 +158,12 @@ def check_H0(cls, values): @root_validator def check_provided_ufuncs(cls, values): sim_specs = values.get("sim_specs") - assert hasattr(sim_specs, "sim_f"), "Simulation function not provided to SimSpecs." assert isinstance(sim_specs.sim_f, Callable), "Simulation function is not callable." if values.get("alloc_specs").alloc_f.__name__ != "give_pregenerated_sim_work": gen_specs = values.get("gen_specs") - assert hasattr(gen_specs, "gen_f"), "Generator function not provided to GenSpecs." - assert isinstance(gen_specs.gen_f, Callable), "Generator function is not callable." + if gen_specs.gen_f is not None: + assert isinstance(gen_specs.gen_f, Callable), "Generator function is not callable." return values @@ -247,8 +246,8 @@ def check_exit_criteria(self): return _check_exit_criteria(self) @model_validator(mode="after") - def check_output_fields(self): - return _check_output_fields(self) + def check_set_gen_specs_from_variables(self): + return _check_set_gen_specs_from_variables(self) @model_validator(mode="after") def check_H0(self): @@ -256,12 +255,11 @@ def check_H0(self): @model_validator(mode="after") def check_provided_ufuncs(self): - assert hasattr(self.sim_specs, "sim_f"), "Simulation function not provided to SimSpecs." assert isinstance(self.sim_specs.sim_f, Callable), "Simulation function is not callable." if self.alloc_specs.alloc_f.__name__ != "give_pregenerated_sim_work": - assert hasattr(self.gen_specs, "gen_f"), "Generator function not provided to GenSpecs." - assert isinstance(self.gen_specs.gen_f, Callable), "Generator function is not callable." + if self.gen_specs.gen_f is not None: + assert isinstance(self.gen_specs.gen_f, Callable), "Generator function is not callable." return self diff --git a/libensemble/worker.py b/libensemble/worker.py index 10823ad8a..2282ef74a 100644 --- a/libensemble/worker.py +++ b/libensemble/worker.py @@ -166,8 +166,8 @@ def __init__( self.workerID = workerID self.libE_specs = libE_specs self.stats_fmt = libE_specs.get("stats_fmt", {}) - self.sim_runner = Runner(sim_specs) - self.gen_runner = Runner(gen_specs) + self.sim_runner = Runner.from_specs(sim_specs) + self.gen_runner = Runner.from_specs(gen_specs) self.runners = {EVAL_SIM_TAG: self.sim_runner.run, EVAL_GEN_TAG: self.gen_runner.run} self.calc_iter = {EVAL_SIM_TAG: 0, EVAL_GEN_TAG: 0} Worker._set_executor(self.workerID, self.comm) @@ -256,6 +256,7 @@ def _handle_calc(self, Work: dict, calc_in: npt.NDArray) -> (npt.NDArray, dict, try: logger.debug(f"Starting {enum_desc}: {calc_id}") + out = None calc = self.runners[calc_type] with timer: if self.EnsembleDirectory.use_calc_dirs(calc_type): @@ -279,8 +280,8 @@ def _handle_calc(self, Work: dict, calc_in: npt.NDArray) -> (npt.NDArray, dict, if tag in [STOP_TAG, PERSIS_STOP] and message is MAN_SIGNAL_FINISH: calc_status = MAN_SIGNAL_FINISH - if out: - if len(out) >= 3: # Out, persis_info, calc_status + if out is not None: + if not isinstance(out, np.ndarray) and len(out) >= 3: # Out, persis_info, calc_status calc_status = out[2] return out elif len(out) == 2: # Out, persis_info OR Out, calc_status