Custom functions: e.g. custom metrics #141

Open · wants to merge 4 commits into base: master
4 changes: 3 additions & 1 deletion amlb/benchmark.py
@@ -279,7 +279,7 @@ class TaskConfig:

def __init__(self, name, fold, metrics, seed,
max_runtime_seconds, cores, max_mem_size_mb, min_vol_size_mb,
input_dir, output_dir):
input_dir, output_dir, extensions):
self.framework = None
self.framework_params = None
self.type = None
@@ -295,6 +295,7 @@ def __init__(self, name, fold, metrics, seed,
self.input_dir = input_dir
self.output_dir = output_dir
self.output_predictions_file = os.path.join(output_dir, "predictions.csv")
self.extensions = extensions

def __json__(self):
return self.__dict__
@@ -350,6 +351,7 @@ def __init__(self, benchmark: Benchmark, task_def, fold):
min_vol_size_mb=task_def.min_vol_size_mb,
input_dir=rconfig().input_dir,
output_dir=benchmark.output_dirs.session,
extensions=rconfig().extensions_files,
)
# allowing to override some task parameters through command line, e.g.: -Xt.max_runtime_seconds=60
if rconfig()['t'] is not None:
8 changes: 8 additions & 0 deletions amlb/results.py
@@ -17,6 +17,7 @@
from .datautils import accuracy_score, confusion_matrix, f1_score, log_loss, balanced_accuracy_score, mean_absolute_error, mean_squared_error, mean_squared_log_error, r2_score, roc_auc_score, read_csv, write_csv, is_data_frame, to_data_frame
from .resources import get as rget, config as rconfig, output_dirs
from .utils import Namespace, backup_file, cached, datetime_iso, memoize, profile
from frameworks.shared.callee import get_extension

log = logging.getLogger(__name__)

@@ -323,6 +324,13 @@ def __init__(self, predictions_df, info=None):
def evaluate(self, metric):
if hasattr(self, metric):
return getattr(self, metric)()
else:
# A metric may be defined twice, once for the automl system to use (e.g.
# as a scikit-learn scorer), and once in the amlb-compatible format.
# The amlb-compatible format is marked with a trailing underscore.
custom_metric = get_extension(rconfig().extensions_files, f"{metric}_")
if custom_metric is not None:
return custom_metric(self)
# raise ValueError("Metric {metric} is not supported for {type}.".format(metric=metric, type=self.type))
log.warning("Metric %s is not supported for %s!", metric, self.type)
return nan
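
For illustration, the amlb-side half of such a custom metric might look like the sketch below. The metric name (`medae`) and the file are hypothetical, and it assumes the result object handed to the extension exposes `truth` and `predictions` the same way the built-in metric methods consume them:

```python
# extensions.py -- hypothetical user extension file, not part of this PR
from sklearn.metrics import median_absolute_error

def medae_(result):
    # The trailing underscore marks the amlb-compatible variant picked up by Result.evaluate();
    # `result` is assumed to expose ground truth and predictions as array-likes.
    return median_absolute_error(result.truth, result.predictions)
```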
5 changes: 3 additions & 2 deletions frameworks/AutoGluon/exec.py
@@ -12,7 +12,7 @@
from autogluon.utils.tabular.utils.savers import save_pd, save_pkl
import autogluon.utils.tabular.metrics as metrics

from frameworks.shared.callee import call_run, result, output_subdir, utils
from frameworks.shared.callee import call_run, get_extension, result, output_subdir, utils

log = logging.getLogger(__name__)

@@ -32,7 +32,8 @@ def run(dataset, config):
rmse=metrics.mean_squared_error, # for now, we can let autogluon optimize training on mse: anyway we compute final score from predictions.
)

perf_metric = metrics_mapping[config.metric] if config.metric in metrics_mapping else None
perf_metric = (metrics_mapping[config.metric] if config.metric in metrics_mapping
else get_extension(config.extensions, config.metric))
if perf_metric is None:
# TODO: figure out if we are going to blindly pass metrics through, or if we use a strict mapping
log.warning("Performance metric %s not supported.", config.metric)
5 changes: 3 additions & 2 deletions frameworks/TPOT/exec.py
@@ -12,7 +12,7 @@
os.environ['MKL_NUM_THREADS'] = '1'
from tpot import TPOTClassifier, TPOTRegressor

from frameworks.shared.callee import call_run, result, output_subdir, utils
from frameworks.shared.callee import call_run, get_extension, result, output_subdir, utils


log = logging.getLogger(__name__)
@@ -34,7 +34,8 @@ def run(dataset, config):
r2='r2',
rmse='neg_mean_squared_error', # TPOT can score on mse, as app computes rmse independently on predictions
)
scoring_metric = metrics_mapping[config.metric] if config.metric in metrics_mapping else None
scoring_metric = (metrics_mapping[config.metric] if config.metric in metrics_mapping
else get_extension(config.extensions, config.metric))
if scoring_metric is None:
raise ValueError("Performance metric {} not supported.".format(config.metric))

5 changes: 3 additions & 2 deletions frameworks/autosklearn/exec.py
@@ -13,7 +13,7 @@
import autosklearn.metrics as metrics
from packaging import version

from frameworks.shared.callee import call_run, result, output_subdir, utils
from frameworks.shared.callee import call_run, get_extension, result, output_subdir, utils

log = logging.getLogger(__name__)

@@ -36,7 +36,8 @@ def run(dataset, config):
rmse=metrics.mean_squared_error, # autosklearn can optimize on mse, and we compute rmse independently on predictions
r2=metrics.r2
)
perf_metric = metrics_mapping[config.metric] if config.metric in metrics_mapping else None
perf_metric = (metrics_mapping[config.metric] if config.metric in metrics_mapping
else get_extension(config.extensions, config.metric))
if perf_metric is None:
# TODO: figure out if we are going to blindly pass metrics through, or if we use a strict mapping
log.warning("Performance metric %s not supported.", config.metric)
36 changes: 34 additions & 2 deletions frameworks/shared/callee.py
@@ -1,3 +1,4 @@
import linecache
import json
import logging
import os
@@ -44,7 +45,38 @@ def output_subdir(name, config):
return subdir


data_keys = re.compile("^(X|y|data)(_.+)?$")
_extensions_ = {}


def get_extension(files, name=None, default=None):
files = [files] if isinstance(files, str) else files

extensions = []
for file in files:
if file in _extensions_:
extensions.append(_extensions_.get(file, {}))
elif os.path.isfile(file):
try:
with open(file) as f:
# linecache and compile are necessary only if we want to inspect code later
# otherwise the following statement is enough:
# exec(f.read(), ext)
linecache.updatecache(f.name)
code = compile(f.read(), f.name, 'exec')
ext = {}
exec(code, ext)
_extensions_[file] = ext
extensions.append(ext)
except Exception as e:
log.warning("Could not load extension file %s: %s", file, str(e))
_extensions_[file] = {}
else:
log.warning("No extensions available at %s", file)

return extensions if name is None else next((ext[name] for ext in extensions if name in ext), default)


_data_keys_ = re.compile("^(X|y|data)(_.+)?$")


def call_run(run_fn):
@@ -53,7 +85,7 @@ def call_run(run_fn):
params = NS.from_dict(json.loads(sys.stdin.read()))

def load_data(name, path, **ignored):
if isinstance(path, str) and data_keys.match(name):
if isinstance(path, str) and _data_keys_.match(name):
return name, np.load(path, allow_pickle=True)
return name, path

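
For reference, a rough usage sketch of the new helper (paths and names are illustrative). Called without a name it loads and caches every listed file; with a name it resolves a single symbol across them:

```python
from frameworks.shared.callee import get_extension

# Load all extension modules listed in the config (returns a list of namespace dicts)...
modules = get_extension(["/home/user/.config/automlbenchmark/extensions.py"])

# ...or look up one symbol across those files, falling back to a default when absent
# (`config` here stands for the object passed to each framework's run(dataset, config)).
medae_fn = get_extension(config.extensions, "medae", default=None)
```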
3 changes: 3 additions & 0 deletions resources/config.yaml
@@ -50,6 +50,9 @@ benchmarks:
max_mem_size_mb: -1 # default amount of memory assigned to each automl task. If <= 0, then the amount of memory is computed from os available memory.
min_vol_size_mb: -1 # default minimum amount of free space required on the volume. If <= 0, skips verification.

extensions_files:
- '{user}/extensions.py'

results:
error_max_length: 200
save: true # set by runbenchmark.py