From 8848d7514d24b3319df9c786f7e4b53819b78e30 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 15 Sep 2023 15:56:41 +0200 Subject: [PATCH] Refactor torch training test to stop using metrics from checkpoint --- tests/torch/test_compression_training.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/torch/test_compression_training.py b/tests/torch/test_compression_training.py index a58b2563caa..405579d8b8d 100644 --- a/tests/torch/test_compression_training.py +++ b/tests/torch/test_compression_training.py @@ -174,9 +174,7 @@ def finalize(self, dataset_dir, tmp_path_factory, weekly_models_path) -> "Compre return self def get_metric(self): - return self.sample_handler.get_metric_value_from_checkpoint( - self.checkpoint_save_dir, self.checkpoint_name, self.config_path - ) + return self.expected_accuracy_ def _get_weight_path(self, weekly_models_path): if self.weights_filename_ is None: @@ -249,9 +247,7 @@ def subnet_expected_accuracy(self, subnet_expected_accuracy: float): return self def get_subnet_metric(self): - return self.sample_handler.get_metric_value_from_checkpoint( - self.checkpoint_save_dir, self.subnet_checkpoint_name - ) + return self.subnet_expected_accuracy_ def _get_weight_path(self, weekly_models_path): return os.path.join(