From 818d94474ca9bcd2e38d755b09b56e0f69d1ade1 Mon Sep 17 00:00:00 2001 From: Guan-JW Date: Thu, 10 Aug 2023 10:35:37 +0800 Subject: [PATCH] Fix the test keys in nasbench201 --- hpobench/benchmarks/nas/nasbench_201.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hpobench/benchmarks/nas/nasbench_201.py b/hpobench/benchmarks/nas/nasbench_201.py index 17bac321..6b062573 100644 --- a/hpobench/benchmarks/nas/nasbench_201.py +++ b/hpobench/benchmarks/nas/nasbench_201.py @@ -259,8 +259,8 @@ def objective_function(self, configuration: Union[CS.Configuration, Dict], for e in range(1, epoch + 1)) for seed in data_seed] # There is a single value for the eval data per seed. (only epoch 200) - test_accuracies = [self.data[seed][structure_str]['eval_acc1es'][f'{valid_key}@{199}'] for seed in data_seed] - test_losses = [self.data[seed][structure_str]['eval_losses'][f'{valid_key}@{199}'] for seed in data_seed] + test_accuracies = [self.data[seed][structure_str]['eval_acc1es'][f'{test_key}@{199}'] for seed in data_seed] + test_losses = [self.data[seed][structure_str]['eval_losses'][f'{test_key}@{199}'] for seed in data_seed] test_times = [np.sum((self.data[seed][structure_str]['eval_times'][f'{test_key}@{199}']) for e in range(1, epoch + 1)) for seed in data_seed]