Skip to content

Commit

Permalink
Fix two bugs in Ensemble class (#1162)
Browse files Browse the repository at this point in the history
* ensure that if we did a temporary cast of a new libE_specs to dict, it gets cast back to a class upon setting. Passthrough nworkers attribute upon setting libE_specs if the Ensemble doesn't have it, but libE_specs does

* fixes and improvements for libE_specs setting and nworkers attributes passthrough, plus a test

* didn't need to do this kind of "passthrough" since the nworkers property returns the correct value
  • Loading branch information
jlnav authored Nov 27, 2023
1 parent 6044725 commit f107c66
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
37 changes: 21 additions & 16 deletions libensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def __init__(
self.sim_specs = sim_specs
self.gen_specs = gen_specs
self.exit_criteria = exit_criteria
self._libE_specs = libE_specs
self.libE_specs = libE_specs
self.alloc_specs = alloc_specs
self.persis_info = persis_info
self.executor = executor
Expand All @@ -284,7 +284,7 @@ def __init__(
self.logger = logger
self.logger.set_level("INFO")

self.nworkers = 0
self._nworkers = 0
self.is_manager = False
self.parsed = False

Expand Down Expand Up @@ -316,23 +316,23 @@ def libE_specs(self, new_specs):
# "not" overwrite the internal libE_specs["comms"], but *only* if parse_args
# was called. Otherwise we can respect the complete set of provided options.

# Convert our libE_specs from dict to class, if its a dict
if isinstance(self._libE_specs, dict):
self._libE_specs = LibeSpecs(**self._libE_specs)
# Respect everything if libE_specs isn't set
if not hasattr(self, "_libE_specs") or not self._libE_specs:
if isinstance(new_specs, dict):
self._libE_specs = LibeSpecs(**new_specs)
else:
self._libE_specs = new_specs
return

# Cast new libE_specs temporarily to dict
if not isinstance(new_specs, dict):
if isinstance(new_specs, LibeSpecs):
new_specs = new_specs.dict(by_alias=True, exclude_none=True, exclude_unset=True)

# Unset "comms" if we already have a libE_specs that contains that field, that came from parse_args
if new_specs.get("comms") and hasattr(self._libE_specs, "comms") and self.parsed:
new_specs.pop("comms")

# Now finally set attribute if we don't have a libE_specs, otherwise update the internal
if not self._libE_specs:
self._libE_specs = new_specs
else:
self._libE_specs.__dict__.update(**new_specs)
self._libE_specs.__dict__.update(**new_specs)

def _refresh_executor(self):
Executor.executor = self.executor or Executor.executor
Expand Down Expand Up @@ -390,10 +390,15 @@ def run(self) -> (npt.NDArray, dict, int):

return self.H, self.persis_info, self.flag

def _nworkers(self):
if self.nworkers:
return self.nworkers
return self.libE_specs.nworkers
@property
def nworkers(self):
return self._nworkers or self.libE_specs.nworkers

@nworkers.setter
def nworkers(self, value):
self._nworkers = value
if self.libE_specs:
self.libE_specs.nworkers = value

def _get_func(self, loaded):
"""Extracts user function specified in loaded dict"""
Expand Down Expand Up @@ -532,7 +537,7 @@ def add_random_streams(self, num_streams: int = 0, seed: str = ""):
if num_streams:
nstreams = num_streams
else:
nstreams = self._nworkers()
nstreams = self.nworkers

self.persis_info = add_unique_random_streams(self.persis_info, nstreams + 1, seed=seed)
return self.persis_info
Expand Down
22 changes: 22 additions & 0 deletions libensemble/tests/unit_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,31 @@ def test_ensemble_init():

e = Ensemble(parse_args=True)
assert hasattr(e.libE_specs, "comms"), "internal parse_args() didn't populate defaults for class's libE_specs"
assert hasattr(e, "nworkers"), "nworkers should've passed from libE_specs to Ensemble class"
assert e.is_manager, "parse_args() didn't populate defaults for class's libE_specs"

assert e.logger.get_level() == 20, "Default log level should be 20."


def test_ensemble_parse_args_false():
from libensemble.ensemble import Ensemble
from libensemble.specs import LibeSpecs

e = Ensemble() # parse_args defaults to False
e.libE_specs = {"comms": "local", "nworkers": 4}
assert hasattr(e, "nworkers"), "nworkers should've passed from libE_specs to Ensemble class"
assert isinstance(e.libE_specs, LibeSpecs), "libE_specs should've been cast to class"

# test pass attribute as dict
e = Ensemble(libE_specs={"comms": "local", "nworkers": 4})
assert hasattr(e, "nworkers"), "nworkers should've passed from libE_specs to Ensemble class"
assert isinstance(e.libE_specs, LibeSpecs), "libE_specs should've been cast to class"

# test that adjusting Ensemble.nworkers also changes libE_specs
e.nworkers = 8
assert e.libE_specs.nworkers == 8, "libE_specs nworkers not adjusted"


def test_from_files():
"""Test that Ensemble() specs dicts resemble setup dicts"""
from libensemble.ensemble import Ensemble
Expand Down Expand Up @@ -91,6 +111,7 @@ def test_full_workflow():
),
exit_criteria=ExitCriteria(gen_max=101),
)

ens.add_random_streams()
ens.run()
if ens.is_manager:
Expand Down Expand Up @@ -146,6 +167,7 @@ def test_flakey_workflow():

if __name__ == "__main__":
test_ensemble_init()
test_ensemble_parse_args_false()
test_from_files()
test_bad_func_loads()
test_full_workflow()
Expand Down

0 comments on commit f107c66

Please sign in to comment.