From 6c7c691fe72942c48ac29a832f919087410b89c0 Mon Sep 17 00:00:00 2001
From: Ryan
Date: Fri, 25 Oct 2024 11:42:03 -0600
Subject: [PATCH] chore: add batches and epochs to TFKerasTrial metrics (#10129)

All of our other training loops now emit this, and asha searches depend on it.
---
 harness/determined/keras/_tf_keras_trial.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/harness/determined/keras/_tf_keras_trial.py b/harness/determined/keras/_tf_keras_trial.py
index 470d34b1121..b4f77f81233 100644
--- a/harness/determined/keras/_tf_keras_trial.py
+++ b/harness/determined/keras/_tf_keras_trial.py
@@ -890,6 +890,13 @@ def _post_train_batch_end(self, num_inputs: int, logs: Dict) -> None:
         if self.env.experiment_config.average_training_metrics_enabled():
             final_metrics = self._allreduce_logs(final_metrics)
 
+        # Inject batches and epochs into avg metrics.
+        # (this is after batches and possibly epochs have been updated)
+        final_metrics["batches"] = final_metrics.get(
+            "batches", self.multiplexer.state.total_batches
+        )
+        final_metrics["epochs"] = final_metrics.get("epochs", self.multiplexer.state.epoch)
+
         self.multiplexer._train_workload_end(final_metrics)
         self._stop_training_check()
 
@@ -962,6 +969,11 @@ def _compute_validation_metrics(self) -> workload.Response:
         step_duration = time.time() - validation_start_time
         logger.info(det.util.make_timing_log("validated", step_duration, num_inputs, num_batches))
 
+        # Inject batches and epochs into validation metrics.
+        # (this is after batches and possibly epochs have been updated)
+        metrics["batches"] = metrics.get("batches", self.multiplexer.state.total_batches)
+        metrics["epochs"] = metrics.get("epochs", self.multiplexer.state.epoch)
+
         self.metric_writer.on_validation_step_end(self.steps_completed, metrics)
         self.upload_tb_files()
         return {"num_inputs": num_inputs, "validation_metrics": metrics}