Skip to content

Commit

Permalink
lf.eval.v2 to support PerExampleCheckpointer.
Browse files Browse the repository at this point in the history
This allows us to save expensive runs at a finer granularity, so that restoration can recover saved examples even in the event of unexpected errors that may cause file corruption.

PiperOrigin-RevId: 706821039
  • Loading branch information
daiyip authored and langfun authors committed Dec 16, 2024
1 parent 51cb7d5 commit 371a8ef
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 26 deletions.
6 changes: 5 additions & 1 deletion langfun/core/eval/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@
from langfun.core.eval.v2 import metrics

from langfun.core.eval.v2.experiment import Plugin

from langfun.core.eval.v2.experiment import Runner
from langfun.core.eval.v2 import runners

# Plugins
from langfun.core.eval.v2.checkpointing import BulkCheckpointer
from langfun.core.eval.v2.checkpointing import PerExampleCheckpointer
from langfun.core.eval.v2.reporting import HtmlReporter


# pylint: enable=g-bad-import-order
# pylint: enable=g-importing-member
112 changes: 96 additions & 16 deletions langfun/core/eval/v2/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Checkpointing evaluation runs."""
import threading

import langfun.core as lf
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib
import pyglove as pg
Expand All @@ -24,21 +25,100 @@


class Checkpointer(experiment_lib.Plugin):
"""Plugin for checkpointing evaluation runs."""
"""Base class for checkpointing evaluation examples."""


class PerExampleCheckpointer(Checkpointer):
"""Checkpointer that saves each example to a separate file."""

checkpoint_filename: str = 'checkpoint.bagz'

def _on_bound(self):
super()._on_bound()
prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
self._checkpoint_file_prefix = prefix
self._checkpoint_file_ext = ext

def on_experiment_start(
self,
runner: Runner,
experiment: Experiment,
) -> None:
"""Creates the checkpoint file."""
if not experiment.is_leaf:
return

# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
def _load_state(ckpt_file):
experiment.load_state(ckpt_file)

experiment_dir = runner.current_run.input_dir(experiment)
if pg.io.path_exists(experiment_dir):
ckpt_files = [
runner.current_run.input_path_for(experiment, filename)
for filename in pg.io.listdir(experiment_dir)
if filename.startswith(self._checkpoint_file_prefix)
and filename.endswith(self._checkpoint_file_ext)
]
else:
ckpt_files = []

for ckpt_file, _, error in lf.concurrent_map(
_load_state, ckpt_files, max_workers=64,
):
if error is not None:
pg.logging.warning(
'Failed to load checkpoint file %s: %s. Skipping the file.',
ckpt_file, error
)

def on_example_complete(
self,
runner: Runner,
experiment: Experiment,
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
if not example.has_error:
def save_state(example: Example):
writer = SequenceWriter(
runner.current_run.output_path_for(
experiment,
(
f'{self._checkpoint_file_prefix}_{example.id}'
f'{self._checkpoint_file_ext}'
)
)
)
writer.add(example)
del writer
runner.background_run(save_state, example)

def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
ext_index = filename.rfind('.')
if ext_index == -1:
return filename, ''
else:
return filename[:ext_index], filename[ext_index:]


class BulkCheckpointer(Checkpointer):
"""Checkpointer that saves all examples to a single file."""

checkpoint_filename: str = 'checkpoint.bagz'

def _on_bound(self):
super()._on_bound()
self._lock = threading.Lock()
self._state_writer = None
self._sequence_writer = None

def on_run_start(
self,
runner: Runner,
root: Experiment,
) -> None:
self._state_writer = {}
self._sequence_writer = {}

def on_run_abort(
self,
Expand All @@ -47,16 +127,16 @@ def on_run_abort(
error: BaseException
) -> None:
with self._lock:
if self._state_writer is not None:
self._state_writer.clear()
if self._sequence_writer is not None:
self._sequence_writer.clear()

def on_run_complete(
self,
runner: Runner,
root: Experiment,
) -> None:
with self._lock:
assert self._state_writer is not None and not self._state_writer
assert self._sequence_writer is not None and not self._sequence_writer

def on_experiment_start(
self,
Expand All @@ -74,14 +154,14 @@ def on_experiment_start(
),
raise_if_not_exist=False
)
state_writer = StateWriter(
sequence_writer = SequenceWriter(
runner.current_run.output_path_for(
experiment, self.checkpoint_filename
)
)
with self._lock:
if self._state_writer is not None:
self._state_writer[experiment.id] = state_writer
if self._sequence_writer is not None:
self._sequence_writer[experiment.id] = sequence_writer

def on_experiment_complete(
self,
Expand All @@ -91,10 +171,10 @@ def on_experiment_complete(
"""Closes the checkpoint file."""
if not experiment.is_leaf:
return
assert experiment.id in self._state_writer
assert experiment.id in self._sequence_writer
with self._lock:
if self._state_writer is not None:
del self._state_writer[experiment.id]
if self._sequence_writer is not None:
del self._sequence_writer[experiment.id]

def on_example_complete(
self,
Expand All @@ -103,13 +183,13 @@ def on_example_complete(
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
assert experiment.id in self._state_writer
assert experiment.id in self._sequence_writer
if not example.has_error:
runner.background_run(self._state_writer[experiment.id].add, example)
runner.background_run(self._sequence_writer[experiment.id].add, example)


class StateWriter:
"""Thread safe state writer."""
class SequenceWriter:
"""Thread safe sequence writer."""

def __init__(self, path: str):
self._lock = threading.Lock()
Expand Down
49 changes: 41 additions & 8 deletions langfun/core/eval/v2/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@
Example = example_lib.Example


class StateWriterTest(unittest.TestCase):
class SequenceWriterTest(unittest.TestCase):

def test_basic(self):
file = os.path.join(tempfile.gettempdir(), 'test.jsonl')
writer = checkpointing.StateWriter(file)
writer = checkpointing.SequenceWriter(file)
example = Example(id=1, input=pg.Dict(x=1), output=2)
writer.add(example)
del writer
self.assertTrue(pg.io.path_exists(file))

def test_error_handling(self):
file = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
writer = checkpointing.StateWriter(file)
writer = checkpointing.SequenceWriter(file)
writer.add(Example(id=1, input=pg.Dict(x=1), output=2))

def f():
Expand All @@ -52,17 +52,50 @@ def f():
self.assertEqual(len(list(iter(f))), 1)


class CheckpointingTest(unittest.TestCase):
class PerExampleCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_checkpointing')
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
experiment = test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.Checkpointer(checkpoint_filename)
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
run = experiment.run(
root_dir, 'new', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._state_writer), 0)
num_processed = {}
for leaf in experiment.leaf_nodes:
for i in range(leaf.num_examples):
example = leaf.state.get(i + 1)
ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
if example.has_error:
self.assertFalse(pg.io.path_exists(ckpt))
else:
self.assertTrue(pg.io.path_exists(ckpt))
with pg.io.open_sequence(ckpt) as f:
self.assertEqual(len(list(iter(f))), 1)
if leaf.id not in num_processed:
self.assertEqual(leaf.progress.num_skipped, 0)
num_processed[leaf.id] = leaf.progress.num_processed

# Run again, should skip existing.
_ = experiment.run(
root_dir, 'latest', runner='sequential', plugins=[checkpointer]
)
for leaf in experiment.leaf_nodes:
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])


class BulkCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
experiment = test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
run = experiment.run(
root_dir, 'new', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._sequence_writer), 0)
num_processed = {}
for leaf in experiment.leaf_nodes:
ckpt = run.output_path_for(leaf, checkpoint_filename)
Expand All @@ -80,7 +113,7 @@ def test_checkpointing(self):
_ = experiment.run(
root_dir, 'latest', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._state_writer), 0)
self.assertEqual(len(checkpointer._sequence_writer), 0)
for leaf in experiment.leaf_nodes:
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])

Expand Down
2 changes: 1 addition & 1 deletion langfun/core/eval/v2/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class RunnerBase(Runner):
] = False

plugins = [
checkpointing.Checkpointer(),
checkpointing.BulkCheckpointer(),
reporting.HtmlReporter(),
]

Expand Down

0 comments on commit 371a8ef

Please sign in to comment.