From da17be26a6f9a9c7cb9cf57e25d469911ac0e185 Mon Sep 17 00:00:00 2001
From: Daiyi Peng
Date: Mon, 16 Dec 2024 12:16:23 -0800
Subject: [PATCH] `lf.eval.v2` to support `PerExampleCheckpointer`.

This allows us to save expensive runs at a finer-grained level, so that
restoration can still pick up saved examples even when an unexpected error
occurs, which could otherwise corrupt the checkpoint file.

PiperOrigin-RevId: 706792918
---
 langfun/core/eval/v2/__init__.py           |   6 +-
 langfun/core/eval/v2/checkpointing.py      | 112 ++++++++++++++++++---
 langfun/core/eval/v2/checkpointing_test.py |  47 +++++++--
 langfun/core/eval/v2/runners.py            |   2 +-
 4 files changed, 142 insertions(+), 25 deletions(-)

diff --git a/langfun/core/eval/v2/__init__.py b/langfun/core/eval/v2/__init__.py
index 4fcb4a8..0ef19e0 100644
--- a/langfun/core/eval/v2/__init__.py
+++ b/langfun/core/eval/v2/__init__.py
@@ -29,10 +29,14 @@
 from langfun.core.eval.v2 import metrics
 
 from langfun.core.eval.v2.experiment import Plugin
-
 from langfun.core.eval.v2.experiment import Runner
 from langfun.core.eval.v2 import runners
 
+# Plugins
+from langfun.core.eval.v2.checkpointing import BulkCheckpointer
+from langfun.core.eval.v2.checkpointing import PerExampleCheckpointer
+from langfun.core.eval.v2.reporting import HtmlReporter
+
 # pylint: enable=g-bad-import-order
 # pylint: enable=g-importing-member
 
diff --git a/langfun/core/eval/v2/checkpointing.py b/langfun/core/eval/v2/checkpointing.py
index bcac0b1..8c6f163 100644
--- a/langfun/core/eval/v2/checkpointing.py
+++ b/langfun/core/eval/v2/checkpointing.py
@@ -14,6 +14,7 @@
 """Checkpointing evaluation runs."""
 
 import threading
+import langfun.core as lf
 from langfun.core.eval.v2 import example as example_lib
 from langfun.core.eval.v2 import experiment as experiment_lib
 import pyglove as pg
@@ -24,21 +25,100 @@
 
 
 class Checkpointer(experiment_lib.Plugin):
-  """Plugin for checkpointing evaluation runs."""
+  """Base class for checkpointing evaluation examples."""
+
+
+class PerExampleCheckpointer(Checkpointer):
+  """Checkpointer that saves each example to a separate file."""
+
+  checkpoint_filename: str = 'checkpoint.bagz'
+
+  def _on_bound(self):
+    super()._on_bound()
+    prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
+    self._checkpoint_file_prefix = prefix
+    self._checkpoint_file_ext = ext
+
+  def on_experiment_start(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+  ) -> None:
+    """Creates the checkpoint file."""
+    if not experiment.is_leaf:
+      return
+
+    # For refresh runs, we don't want to load the previous state.
+    if not runner.current_run.refresh:
+      def _load_state(ckpt_file):
+        experiment.load_state(ckpt_file)
+
+      experiment_dir = runner.current_run.input_dir(experiment)
+      if pg.io.path_exists(experiment_dir):
+        ckpt_files = [
+            runner.current_run.input_path_for(experiment, filename)
+            for filename in pg.io.listdir(experiment_dir)
+            if filename.startswith(self._checkpoint_file_prefix)
+            and filename.endswith(self._checkpoint_file_ext)
+        ]
+      else:
+        ckpt_files = []
+
+      for ckpt_file, _, error in lf.concurrent_map(
+          _load_state, ckpt_files, max_workers=64,
+      ):
+        if error is not None:
+          pg.logging.warning(
+              'Failed to load checkpoint file %s: %s. Skipping the file.',
+              ckpt_file, error
+          )
+
+  def on_example_complete(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      example: Example,
+  ) -> None:
+    """Saves the example to the checkpoint file."""
+    if not example.has_error:
+      def save_state(example: Example):
+        writer = SequenceWriter(
+            runner.current_run.output_path_for(
+                experiment,
+                (
+                    f'{self._checkpoint_file_prefix}_{example.id}'
+                    f'{self._checkpoint_file_ext}'
+                )
+            )
+        )
+        writer.add(example)
+        del writer
+      runner.background_run(save_state, example)
+
+  def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
+    ext_index = filename.rfind('.')
+    if ext_index == -1:
+      return filename, ''
+    else:
+      return filename[:ext_index], filename[ext_index:]
+
+
+class BulkCheckpointer(Checkpointer):
+  """Checkpointer that saves all examples to a single file."""
 
   checkpoint_filename: str = 'checkpoint.bagz'
 
   def _on_bound(self):
     super()._on_bound()
     self._lock = threading.Lock()
-    self._state_writer = None
+    self._sequence_writer = None
 
   def on_run_start(
       self,
       runner: Runner,
       root: Experiment,
   ) -> None:
-    self._state_writer = {}
+    self._sequence_writer = {}
 
   def on_run_abort(
       self,
@@ -47,8 +127,8 @@ def on_run_abort(
       error: BaseException
   ) -> None:
     with self._lock:
-      if self._state_writer is not None:
-        self._state_writer.clear()
+      if self._sequence_writer is not None:
+        self._sequence_writer.clear()
 
   def on_run_complete(
       self,
@@ -56,7 +136,7 @@ def on_run_complete(
       root: Experiment,
   ) -> None:
     with self._lock:
-      assert self._state_writer is not None and not self._state_writer
+      assert self._sequence_writer is not None and not self._sequence_writer
 
   def on_experiment_start(
       self,
@@ -74,14 +154,14 @@ def on_experiment_start(
           ),
          raise_if_not_exist=False
       )
-    state_writer = StateWriter(
+    sequence_writer = SequenceWriter(
         runner.current_run.output_path_for(
             experiment, self.checkpoint_filename
         )
     )
     with self._lock:
-      if self._state_writer is not None:
-        self._state_writer[experiment.id] = state_writer
+      if self._sequence_writer is not None:
+        self._sequence_writer[experiment.id] = sequence_writer
 
   def on_experiment_complete(
       self,
@@ -91,10 +171,10 @@ def on_experiment_complete(
     """Closes the checkpoint file."""
     if not experiment.is_leaf:
       return
-    assert experiment.id in self._state_writer
+    assert experiment.id in self._sequence_writer
     with self._lock:
-      if self._state_writer is not None:
-        del self._state_writer[experiment.id]
+      if self._sequence_writer is not None:
+        del self._sequence_writer[experiment.id]
 
   def on_example_complete(
       self,
@@ -103,13 +183,13 @@ def on_example_complete(
       example: Example,
   ) -> None:
     """Saves the example to the checkpoint file."""
-    assert experiment.id in self._state_writer
+    assert experiment.id in self._sequence_writer
     if not example.has_error:
-      runner.background_run(self._state_writer[experiment.id].add, example)
+      runner.background_run(self._sequence_writer[experiment.id].add, example)
 
 
-class StateWriter:
-  """Thread safe state writer."""
+class SequenceWriter:
+  """Thread safe sequence writer."""
 
   def __init__(self, path: str):
     self._lock = threading.Lock()
diff --git a/langfun/core/eval/v2/checkpointing_test.py b/langfun/core/eval/v2/checkpointing_test.py
index 5e105f5..e0bde6f 100644
--- a/langfun/core/eval/v2/checkpointing_test.py
+++ b/langfun/core/eval/v2/checkpointing_test.py
@@ -28,7 +28,7 @@ class StateWriterTest(unittest.TestCase):
 
   def test_basic(self):
     file = os.path.join(tempfile.gettempdir(), 'test.jsonl')
-    writer = checkpointing.StateWriter(file)
+    writer = checkpointing.SequenceWriter(file)
     example = Example(id=1, input=pg.Dict(x=1), output=2)
     writer.add(example)
     del writer
@@ -36,7 +36,7 @@ def test_basic(self):
 
   def test_error_handling(self):
     file = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
-    writer = checkpointing.StateWriter(file)
+    writer = checkpointing.SequenceWriter(file)
     writer.add(Example(id=1, input=pg.Dict(x=1), output=2))
 
     def f():
@@ -52,17 +52,50 @@ def f():
     self.assertEqual(len(list(iter(f))), 1)
 
 
-class CheckpointingTest(unittest.TestCase):
+class PerExampleCheckpointerTest(unittest.TestCase):
 
   def test_checkpointing(self):
-    root_dir = os.path.join(tempfile.gettempdir(), 'test_checkpointing')
+    root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
     experiment = test_helper.test_experiment()
     checkpoint_filename = 'checkpoint.jsonl'
-    checkpointer = checkpointing.Checkpointer(checkpoint_filename)
+    checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
     run = experiment.run(
         root_dir, 'new', runner='sequential', plugins=[checkpointer]
     )
-    self.assertEqual(len(checkpointer._state_writer), 0)
+    num_processed = {}
+    for leaf in experiment.leaf_nodes:
+      for i in range(leaf.num_examples):
+        example = leaf.state.get(i + 1)
+        ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
+        if example.has_error:
+          self.assertFalse(pg.io.path_exists(ckpt))
+        else:
+          self.assertTrue(pg.io.path_exists(ckpt))
+          with pg.io.open_sequence(ckpt) as f:
+            self.assertEqual(len(list(iter(f))), 1)
+      if leaf.id not in num_processed:
+        self.assertEqual(leaf.progress.num_skipped, 0)
+        num_processed[leaf.id] = leaf.progress.num_processed
+
+    # Run again, should skip existing.
+    _ = experiment.run(
+        root_dir, 'latest', runner='sequential', plugins=[checkpointer]
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
+
+
+class BulkCheckpointerTest(unittest.TestCase):
+
+  def test_checkpointing(self):
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
+    experiment = test_helper.test_experiment()
+    checkpoint_filename = 'checkpoint.jsonl'
+    checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
+    run = experiment.run(
+        root_dir, 'new', runner='sequential', plugins=[checkpointer]
+    )
+    self.assertEqual(len(checkpointer._sequence_writer), 0)
     num_processed = {}
     for leaf in experiment.leaf_nodes:
       ckpt = run.output_path_for(leaf, checkpoint_filename)
@@ -80,7 +113,7 @@ def test_checkpointing(self):
     _ = experiment.run(
         root_dir, 'latest', runner='sequential', plugins=[checkpointer]
     )
-    self.assertEqual(len(checkpointer._state_writer), 0)
+    self.assertEqual(len(checkpointer._sequence_writer), 0)
     for leaf in experiment.leaf_nodes:
       self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
 
diff --git a/langfun/core/eval/v2/runners.py b/langfun/core/eval/v2/runners.py
index 4ac4362..6bc4656 100644
--- a/langfun/core/eval/v2/runners.py
+++ b/langfun/core/eval/v2/runners.py
@@ -53,7 +53,7 @@ class RunnerBase(Runner):
   ] = False
 
   plugins = [
-      checkpointing.Checkpointer(),
+      checkpointing.BulkCheckpointer(),
       reporting.HtmlReporter(),
   ]
 
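
Usage sketch: a minimal, illustrative way to enable the new plugin, based only
on the API added in this patch. `my_experiment` and the root directory are
hypothetical placeholders; the `run(...)` call follows the pattern used in
checkpointing_test.py.

    from langfun.core.eval import v2 as lf_eval

    # Hypothetical experiment factory; any lf.eval.v2 Experiment would do here.
    experiment = my_experiment()

    # Write one checkpoint file per example (e.g. checkpoint_1.bagz) instead of
    # a single bulk checkpoint.bagz per leaf experiment. BulkCheckpointer
    # remains the default plugin installed by runners.py.
    checkpointer = lf_eval.PerExampleCheckpointer('checkpoint.bagz')

    # On a 'latest' run, previously written per-example files are reloaded in
    # on_experiment_start, so already-completed examples are skipped.
    experiment.run(
        '/tmp/my_eval_root', 'latest', runner='sequential',
        plugins=[checkpointer],
    )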