Skip to content

Commit

Permalink
Tests for converting results to numerics (microsoft#569)
Browse files Browse the repository at this point in the history
Adds tests to microsoft#567
  • Loading branch information
bpkroth authored Oct 27, 2023
1 parent 6d6f897 commit 92519d6
Show file tree
Hide file tree
Showing 16 changed files with 94 additions and 45 deletions.
2 changes: 1 addition & 1 deletion mlos_bench/mlos_bench/environments/base_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def teardown(self) -> None:
assert self._in_context
self._is_ready = False

def run(self) -> Tuple[Status, Optional[Dict[str, float]]]:
def run(self) -> Tuple[Status, Optional[Dict[str, TunableValue]]]:
"""
Execute the run script for this environment.
Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/environments/composite_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mlos_bench.services.base_service import Service
from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -178,7 +179,7 @@ def teardown(self) -> None:
env_context.teardown()
super().teardown()

def run(self) -> Tuple[Status, Optional[Dict[str, float]]]:
def run(self) -> Tuple[Status, Optional[Dict[str, TunableValue]]]:
"""
Submit a new experiment to the environment.
Return the result of the *last* child environment if successful,
Expand Down
5 changes: 3 additions & 2 deletions mlos_bench/mlos_bench/environments/local/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mlos_bench.environments.script_env import ScriptEnv
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import path_join

Expand Down Expand Up @@ -151,7 +152,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -

return self._is_ready

def run(self) -> Tuple[Status, Optional[Dict[str, float]]]:
def run(self) -> Tuple[Status, Optional[Dict[str, TunableValue]]]:
"""
Run a script in the local scheduler environment.
Expand All @@ -169,7 +170,7 @@ def run(self) -> Tuple[Status, Optional[Dict[str, float]]]:

assert self._temp_dir is not None

stdout_data: Dict[str, float] = {}
stdout_data: Dict[str, TunableValue] = {}
if self._script_run:
(return_code, output) = self._local_exec(self._script_run, self._temp_dir)
if return_code != 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _download_files(self, ignore_missing: bool = False) -> None:
if not ignore_missing:
raise ex

def run(self) -> Tuple[Status, Optional[Dict[str, float]]]:
def run(self) -> Tuple[Status, Optional[Dict[str, TunableValue]]]:
"""
Download benchmark results from the shared storage
and run post-processing scripts locally.
Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/environments/mock_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mlos_bench.services.base_service import Service
from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
from mlos_bench.tunables import Tunable, TunableGroups
from mlos_bench.tunables import Tunable, TunableGroups, TunableValue

_LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self,
self._metrics = self.config.get("metrics", ["score"])
self._is_ready = True

def run(self) -> Tuple[Status, Optional[Dict[str, float]]]:
def run(self) -> Tuple[Status, Optional[Dict[str, TunableValue]]]:
"""
Produce mock benchmark data for one experiment.
Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/environments/remote/remote_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.services.types.host_ops_type import SupportsHostOps
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -110,7 +111,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -

return self._is_ready

def run(self) -> Tuple[Status, Optional[Dict[str, float]]]:
def run(self) -> Tuple[Status, Optional[Dict[str, TunableValue]]]:
"""
Runs the run script on the remote environment.
Expand Down
9 changes: 6 additions & 3 deletions mlos_bench/mlos_bench/environments/script_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@

from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups

from mlos_bench.util import try_parse_val

_LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -100,7 +103,7 @@ def _get_env_params(self) -> Dict[str, str]:

return {key_sub: str(self._params[key]) for (key_sub, key) in rename.items()}

def _extract_stdout_results(self, stdout: str) -> Dict[str, float]:
def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]:
"""
Extract the results from the stdout of the script.
Expand All @@ -111,10 +114,10 @@ def _extract_stdout_results(self, stdout: str) -> Dict[str, float]:
Returns
-------
results : Dict[str, float]
results : Dict[str, TunableValue]
A dictionary of results extracted from the stdout.
"""
if not self._results_stdout_pattern:
return {}
_LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout)
return {key: float(val) for (key, val) in self._results_stdout_pattern.findall(stdout)}
return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)}
33 changes: 3 additions & 30 deletions mlos_bench/mlos_bench/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.util import BaseTypeVar
from mlos_bench.util import BaseTypeVar, try_parse_val

from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
Expand Down Expand Up @@ -202,33 +202,6 @@ def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> T

return (args, args_rest)

@staticmethod
def _try_parse_val(val: str) -> TunableValue:
"""
Try to parse the value as an int or float, otherwise return the string.
This can help with config schema validation to make sure early on that
the args we're expecting are the right type.
Parameters
----------
val : str
The initial cmd line arg value.
Returns
-------
TunableValue
The parsed value.
"""
try:
if "." in val:
return float(val)
else:
return int(val)
except ValueError:
pass
return str(val)

@staticmethod
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
"""
Expand All @@ -245,12 +218,12 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
key = elem[2:]
kv_split = key.split("=", 1)
if len(kv_split) == 2:
config[kv_split[0].strip()] = Launcher._try_parse_val(kv_split[1])
config[kv_split[0].strip()] = try_parse_val(kv_split[1])
key = None
else:
if key is None:
raise ValueError("Command line argument has no key: " + elem)
config[key.strip()] = Launcher._try_parse_val(elem)
config[key.strip()] = try_parse_val(elem)
key = None

if key is not None:
Expand Down
5 changes: 4 additions & 1 deletion mlos_bench/mlos_bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ def _run(env_context: Environment, opt: Optimizer,

# FIXME: Use the actual timestamp from the benchmark.
trial.update(status, datetime.utcnow(), results)
opt.register(trial.tunables, status, results)
# Filter out non-numeric scores from the optimizer.
scores = results if not isinstance(results, dict) \
else {k: float(v) for (k, v) in results.items() if isinstance(v, (int, float))}
opt.register(trial.tunables, status, scores)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/tests/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@

from mlos_bench.environments.base_environment import Environment

from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups


def check_env_success(env: Environment,
tunable_groups: TunableGroups,
expected_results: Dict[str, Union[float, str]],
expected_results: Dict[str, Union[TunableValue]],
expected_telemetry: List[Tuple[datetime, str, Any]],
global_config: Optional[dict] = None) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions mlos_bench/mlos_bench/tests/environments/local/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mlos_bench.environments.local.local_env import LocalEnv
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Unit tests for extracting data from LocalEnv stdout.
"""

import sys

from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tests.environments import check_env_success
from mlos_bench.tests.environments.local import create_local_env
Expand Down Expand Up @@ -45,12 +47,14 @@ def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None:
"echo 'latency,111'",
"echo 'throughput,222'",
"echo 'score,0.999'",
"echo 'stdout-msg,string'",
"echo '-------------------'", # Should be ignored
"echo 'metric,value' > output.csv",
"echo 'extra1,333' >> output.csv",
"echo 'extra2,444' >> output.csv",
"echo 'file-msg,string' >> output.csv",
],
"results_stdout_pattern": r"(\w+),([0-9.]+)",
"results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)",
"read_results_file": "output.csv",
})

Expand All @@ -60,8 +64,10 @@ def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None:
"latency": 111.0,
"throughput": 222.0,
"score": 0.999,
"stdout-msg": "string",
"extra1": 333.0,
"extra2": 444.0,
"file-msg": "string " if sys.platform == "win32" else "string",
},
expected_telemetry=[],
)
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]:
assert status.is_succeeded()
assert output is not None
score = output['score']
assert isinstance(score, float)
assert 60 <= score <= 120
logger("score: %s", str(score))

Expand Down
26 changes: 26 additions & 0 deletions mlos_bench/mlos_bench/tests/util_try_parse_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Unit tests for try_parse_val utility function.
"""

import math

from mlos_bench.util import try_parse_val


def test_try_parse_val() -> None:
"""
Check that we can retrieve git info about the current repository correctly.
"""
assert try_parse_val(None) is None
assert try_parse_val("1") == int(1)
assert try_parse_val("1.1") == float(1.1)
assert try_parse_val("1e6") == float(1e6)
res = try_parse_val("NaN")
assert isinstance(res, float) and math.isnan(res)
res = try_parse_val("inf")
assert isinstance(res, float) and math.isinf(res)
assert try_parse_val("str") == str("str")
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/tunables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
Tunables classes for Environments in mlos_bench.
"""

from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups

__all__ = [
'Tunable',
'TunableValue',
'TunableGroups',
]
31 changes: 31 additions & 0 deletions mlos_bench/mlos_bench/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,34 @@ def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
_LOG.debug("Current git branch: %s %s", git_repo, git_commit)
rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
return (git_repo, git_commit, rel_path.replace("\\", "/"))


# Note: to avoid circular imports, we don't specify TunableValue here.
def try_parse_val(val: Optional[str]) -> Optional[Union[int, float, str]]:
"""
Try to parse the value as an int or float, otherwise return the string.
This can help with config schema validation to make sure early on that
the args we're expecting are the right type.
Parameters
----------
val : str
The initial cmd line arg value.
Returns
-------
TunableValue
The parsed value.
"""
if val is None:
return val
try:
val_float = float(val)
try:
val_int = int(val)
return val_int if val_int == val_float else val_float
except ValueError:
return val_float
except ValueError:
return str(val)

0 comments on commit 92519d6

Please sign in to comment.