Skip to content

Commit

Permalink
chore: add batches and epochs to TFKerasTrial metrics (#10129)
Browse files Browse the repository at this point in the history
All of our other training loops now emit this, and asha searches depend on it.
  • Loading branch information
rb-determined-ai authored Oct 25, 2024
1 parent 544c404 commit 6c7c691
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions harness/determined/keras/_tf_keras_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,13 @@ def _post_train_batch_end(self, num_inputs: int, logs: Dict) -> None:
if self.env.experiment_config.average_training_metrics_enabled():
final_metrics = self._allreduce_logs(final_metrics)

# Inject batches and epochs into avg metrics.
# (this is after batches and possibly epochs have been updated)
final_metrics["batches"] = final_metrics.get(
"batches", self.multiplexer.state.total_batches
)
final_metrics["epochs"] = final_metrics.get("epochs", self.multiplexer.state.epoch)

self.multiplexer._train_workload_end(final_metrics)
self._stop_training_check()

Expand Down Expand Up @@ -962,6 +969,11 @@ def _compute_validation_metrics(self) -> workload.Response:
step_duration = time.time() - validation_start_time
logger.info(det.util.make_timing_log("validated", step_duration, num_inputs, num_batches))

# Inject batches and epochs into validation metrics.
# (this is after batches and possibly epochs have been updated)
metrics["batches"] = metrics.get("batches", self.multiplexer.state.total_batches)
metrics["epochs"] = metrics.get("epochs", self.multiplexer.state.epoch)

self.metric_writer.on_validation_step_end(self.steps_completed, metrics)
self.upload_tb_files()
return {"num_inputs": num_inputs, "validation_metrics": metrics}
Expand Down

0 comments on commit 6c7c691

Please sign in to comment.