fix: keras continue from cloud checkpoint (#10192)
What happened was an unfortunate combination of fancy context managers
and early exits, where I was unwittingly deleting checkpoints before
keras could read them, but only in the case of checkpoints from
different trial IDs.  I had tested pause/continue, which didn't have the
bug.

The new code is structurally incapable of the same bug, and has a
regression test anyway.
rb-determined-ai authored Nov 1, 2024
1 parent 2afecfc commit ea1c694
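
For readers following along, here is a minimal, self-contained sketch of the failure pattern described in the commit message. The names (`restore_path`, `load_buggy`, `load_fixed`, `_read_state`) and the structure are illustrative only, not the actual `_callback.py` implementation; the sketch only assumes what the diff below shows, namely that `_load()` returns an `ExitStack` so the downloaded checkpoint outlives the call, and that the fix moves the early-exit state loading into a void helper.

```python
import contextlib
import pathlib
import tempfile
from typing import Dict, Iterator, Optional


@contextlib.contextmanager
def restore_path(storage_id: str) -> Iterator[pathlib.Path]:
    # Stand-in for a cloud-backed restore: the local download is deleted on exit.
    with tempfile.TemporaryDirectory() as tmp:
        path = pathlib.Path(tmp)
        (path / "model_checkpoint").write_text(f"weights for {storage_id}")
        yield path


def _read_state(path: pathlib.Path) -> Optional[Dict[str, int]]:
    # Stand-in for unpickling "callback_state"; pretend it came from a different trial.
    return {"continue_id": 1, "epoch": 3}


def load_buggy(checkpoint: str, continue_id: int) -> Optional[contextlib.ExitStack]:
    # Old shape: the early exit sits inside the `with`, so on the "different
    # trial ID" path the stack closes as we return and the download is deleted
    # immediately, before the lazily read weights are actually consumed.
    with contextlib.ExitStack() as exit_stack:
        path = exit_stack.enter_context(restore_path(checkpoint))
        state = _read_state(path)
        if state is None or state["continue_id"] != continue_id:
            return None  # BUG: leaving the `with` deletes `path` right here.
        return exit_stack.pop_all()  # happy path: caller closes it after the first batch


def load_fixed(checkpoint: str, continue_id: int) -> Optional[contextlib.ExitStack]:
    # New shape: the early exits live in a helper that returns nothing, so they
    # can never skip handing the ExitStack back to the caller.
    def load_training_state(path: pathlib.Path) -> None:
        state = _read_state(path)
        if state is None or state["continue_id"] != continue_id:
            return
        # ... restore steps_completed, epoch, etc. ...

    with contextlib.ExitStack() as exit_stack:
        path = exit_stack.enter_context(restore_path(checkpoint))
        load_training_state(path)
        return exit_stack.pop_all()


if __name__ == "__main__":
    stack = load_buggy("ckpt-from-other-trial", continue_id=2)
    print("buggy path keeps files alive:", stack is not None)  # False
    stack = load_fixed("ckpt-from-other-trial", continue_id=2)
    print("fixed path keeps files alive:", stack is not None)  # True
    if stack is not None:
        stack.close()  # in the real callback, closed after the first train batch
```

`ExitStack.pop_all()` transfers the cleanup callbacks to a new stack that the caller owns. The regression test below asserts exactly this ordering by logging `restore_path:exit` only after `first_train_batch_end`.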
Showing 2 changed files with 61 additions and 30 deletions.
60 changes: 32 additions & 28 deletions harness/determined/keras/_callback.py
@@ -1,6 +1,7 @@
import contextlib
import logging
import os
import pathlib
import pickle
import shutil
import tempfile
@@ -337,34 +338,8 @@ def _load(self, checkpoint: Optional[str]) -> Optional[contextlib.ExitStack]:
# Load model.
self.load_model(self.model, str(path / "model_checkpoint"), self._core.distributed)

# Load training state also.
state_path = path / "callback_state"
if not state_path.exists():
return None
with state_path.open("rb") as f:
state = pickle.load(f)
if state["continue_id"] != self._continue_id:
return None
# Continue training where we left off.
self._steps_completed = state["steps_completed"]
self._training_length = state["training_length"]
self._validation_length = state["validation_length"]
initial_epoch: int = state["epoch"] + 1

# HACK: Trick the training loop into starting on a different epoch. Internally, this is
# how keras.callbacks.BackupAndRestore() sets the initial_epoch.
class WorkerTrainingState:
# For tf.keras.
def maybe_load_initial_epoch_from_ckpt(*_: Any, **__: Any) -> int:
return initial_epoch

# For plain keras.
def maybe_load_initial_counters_from_ckpt(*_: Any, **__: Any) -> Tuple[int, int]:
# We only save on epoch boundaries.
initial_batch = 0
return initial_epoch, initial_batch

self.model._training_state = WorkerTrainingState()
# Load our own state.
self._load_training_state(path)

# Success! Don't delete the checkpoint until after the first batch runs though, because
# the checkpoint isn't actually read until then.
@@ -373,6 +348,35 @@ def maybe_load_initial_counters_from_ckpt(*_: Any, **__: Any) -> Tuple[int, int]:
# mypy thinks it's possible to arrive here, but it isn't.
raise RuntimeError("impossible codepath")

def _load_training_state(self, path: pathlib.Path) -> None:
state_path = path / "callback_state"
if not state_path.exists():
return
with state_path.open("rb") as f:
state = pickle.load(f)
if state["continue_id"] != self._continue_id:
return
# Continue training where we left off.
self._steps_completed = state["steps_completed"]
self._training_length = state["training_length"]
self._validation_length = state["validation_length"]
initial_epoch: int = state["epoch"] + 1

# HACK: Trick the training loop into starting on a different epoch. Internally, this is
# how keras.callbacks.BackupAndRestore() sets the initial_epoch.
class WorkerTrainingState:
# For tf.keras.
def maybe_load_initial_epoch_from_ckpt(*_: Any, **__: Any) -> int:
return initial_epoch

# For plain keras.
def maybe_load_initial_counters_from_ckpt(*_: Any, **__: Any) -> Tuple[int, int]:
# We only save on epoch boundaries.
initial_batch = 0
return initial_epoch, initial_batch

self.model._training_state = WorkerTrainingState()

def save_model(
self, model: models.Model, path: str, distributed: core.DistributedContext
) -> None:
31 changes: 29 additions & 2 deletions harness/tests/experiment/keras/test_callback.py
@@ -1,10 +1,11 @@
import contextlib
import json
import os
import pathlib
import re
import subprocess
import sys
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
from unittest import mock

import keras
@@ -30,8 +31,21 @@ def mock_core_context(
"""
# Set up a functional DistributedContext.
distributed = distributed or core.DummyDistributedContext()

# Set up a functional CheckpointContext.
storage_manager = storage.SharedFSStorageManager(path)
class StorageManagerForTesting(storage.SharedFSStorageManager):
@contextlib.contextmanager
def restore_path(
self, src: str, selector: Optional[storage.Selector] = None
) -> Iterator[pathlib.Path]:
events.append(("restore_path:enter", None))
try:
with super().restore_path(src, selector) as x:
yield x
finally:
events.append(("restore_path:exit", None))

storage_manager = StorageManagerForTesting(path)
checkpoint = core.DummyCheckpointContext(distributed, storage_manager)

# Mock everything else, logging report-like calls to events.
@@ -74,6 +88,7 @@ class DeterminedCallbackForTesting(det.keras.DeterminedCallback):

def __init__(self, events: utils.Events, *args: Any, **kwargs: Any) -> None:
self.events = events
self.first_train_batch_end = False
super().__init__(*args, **kwargs)

def on_train_begin(self, logs: Any) -> None:
@@ -82,6 +97,12 @@ def on_train_begin(self, logs: Any) -> None:
fourdigits = "%.4f" % weight
self.events.append((f"after_train_begin:{fourdigits}", weight))

def on_train_batch_end(self, batch: int, logs: Any) -> None:
if not self.first_train_batch_end:
self.first_train_batch_end = True
self.events.append(("first_train_batch_end", None))
super().on_train_batch_end(batch, logs)

def on_epoch_end(self, epoch: int, logs: Any) -> None:
self.events.append((f"before_epoch_end:{epoch}", logs))
super().on_epoch_end(epoch, logs)
@@ -250,12 +271,15 @@ def test_save_restore_and_warm_start(tmp_path: pathlib.Path, eager: bool) -> None:
# - initial weight is nonzero (checkpoint was loaded)
# - initial epoch is nonzero (training state was loaded)
# - steps_completed was properly restored
# - checkpoint is not destroyed until first batch is completed
events = do_fit(tmp_path, eager=eager, checkpoint=ckpt, continue_id=1)
utils.assert_events_match(
events,
"set_status:restoring",
"load_model",
"after_train_begin:%.4f" % weight,
"first_train_batch_end",
"restore_path:exit",
"!after_epoch_end:0",
"before_epoch_end:1",
"report_metrics:training:16",
@@ -267,12 +291,15 @@ def test_save_restore_and_warm_start(tmp_path: pathlib.Path, eager: bool) -> None:
# - initial weight is nonzero (no checkpoint was loaded)
# - initial epoch is zero (no training state was loaded)
# - steps_completed was properly reset
# - checkpoint is not destroyed until first batch is completed
events = do_fit(tmp_path, eager=eager, checkpoint=ckpt, continue_id=2)
utils.assert_events_match(
events,
"set_status:restoring",
"load_model",
"after_train_begin:%.4f" % weight,
"first_train_batch_end",
"restore_path:exit",
"report_metrics:training:8",
"after_epoch_end:0",
"after_epoch_end:1",
