diff --git a/langfun/core/eval/v2/checkpointing.py b/langfun/core/eval/v2/checkpointing.py index 65ca6c9..006afac 100644 --- a/langfun/core/eval/v2/checkpointing.py +++ b/langfun/core/eval/v2/checkpointing.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Checkpointing evaluation runs.""" +import abc import threading import traceback @@ -28,21 +29,74 @@ class Checkpointer(experiment_lib.Plugin): """Base class for checkpointing evaluation examples.""" - def on_experiment_start(self, experiment: Experiment): + def on_experiment_start( + self, + runner: Runner, + experiment: Experiment + ) -> None: + if not experiment.is_leaf: + return + + # For refresh runs, we don't want to load the previous state. + if not runner.current_run.refresh: + if runner.current_run.input_root != runner.current_run.output_root: + experiment.info( + f'Warm starting from directory: {runner.current_run.input_root}.' + ) + self._load_experiment(runner, experiment) + if experiment.state.evaluated_examples: + loaded_example_ids = list( + sorted(experiment.state.evaluated_examples.keys()) + ) + example_ids_to_evaluate = ( + set(runner.current_run.example_ids) if runner.current_run.example_ids + else set(range(1, experiment.num_examples + 1)) + ) + example_ids_to_evaluate -= set(loaded_example_ids) + experiment.info( - 'Loaded %d examples from checkpoint files. Example IDs: %s' % - ( - len(experiment.state.evaluated_examples), - list(sorted(experiment.state.evaluated_examples.keys())) - ), + f'{len(experiment.state.evaluated_examples)} examples have been ' + 'loaded from checkpoint files. Their outputs will be used ' + f'for recomputing metrics. Example IDs: {loaded_example_ids}' + ) + experiment.info( + f'{len(example_ids_to_evaluate)} examples will be processed from ' + f'scratch. Example IDs: {list(sorted(example_ids_to_evaluate))}' ) else: experiment.info( - 'No previous evaluated examples are loaded. 
' + 'No examples are loaded from checkpoint files. ' f'Experiment {experiment.id} starts from scratch.' ) + def on_example_complete( + self, + runner: Runner, + experiment: Experiment, + example: Example, + ) -> None: + """Saves the example to the checkpoint file.""" + if example.has_error: + experiment.warning( + f'Example {example.id} has error. Skipping checkpointing.' + ) + else: + self._save_example(runner, experiment, example) + + @abc.abstractmethod + def _load_experiment(self, runner: Runner, experiment: Experiment) -> None: + """Loads the experiment state from checkpoint files.""" + + @abc.abstractmethod + def _save_example( + self, + runner: Runner, + experiment: Experiment, + example: Example, + ) -> None: + """Saves an evaluated example.""" + class PerExampleCheckpointer(Checkpointer): """Checkpointer that saves each example to a separate file.""" @@ -55,80 +109,86 @@ def _on_bound(self): self._checkpoint_file_prefix = prefix self._checkpoint_file_ext = ext - def on_experiment_start( + def _load_experiment( self, runner: Runner, experiment: Experiment, ) -> None: """Creates the checkpoint file.""" - if not experiment.is_leaf: - return + experiment_dir = runner.current_run.input_dir(experiment) + if pg.io.path_exists(experiment_dir): + ckpt_files = [ + runner.current_run.input_path_for(experiment, filename) + for filename in pg.io.listdir(experiment_dir) + if filename.startswith(self._checkpoint_file_prefix) + and filename.endswith(self._checkpoint_file_ext) + ] + else: + ckpt_files = [] - # For refresh runs, we don't want to load the previous state. - if not runner.current_run.refresh: - if runner.current_run.input_root != runner.current_run.output_root: - experiment.info( - f'Warm starting from directory: {runner.current_run.input_root}.' + experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.') + + # Load the checkpoint files in parallel. 
+ context = dict(counter=0, counter_lock=threading.Lock()) + def _load_state(ckpt_file): + error = None + with pg.timeit() as t: + try: + experiment.load_state(ckpt_file) + except BaseException as e: # pylint: disable=broad-except + error = e + finally: + with context['counter_lock']: + context['counter'] += 1 + + progress_str = f'{context["counter"]}/{len(ckpt_files)}' + if error is None: + experiment.info( + f'Loaded checkpoint file {ckpt_file} in {t.elapse:.2f} ' + f'seconds. ({progress_str})' + ) + else: + experiment.warning( + f'Failed to load checkpoint file {ckpt_file}: {error}. ' + f'Skipping the file. ({progress_str})' + ) + + _ = list( + lf.concurrent_map( + _load_state, ckpt_files, max_workers=16, silence_on_errors=None ) - def _load_state(ckpt_file): - experiment.load_state(ckpt_file) - - experiment_dir = runner.current_run.input_dir(experiment) - if pg.io.path_exists(experiment_dir): - ckpt_files = [ - runner.current_run.input_path_for(experiment, filename) - for filename in pg.io.listdir(experiment_dir) - if filename.startswith(self._checkpoint_file_prefix) - and filename.endswith(self._checkpoint_file_ext) - ] - else: - ckpt_files = [] - - for ckpt_file, _, error in lf.concurrent_map( - _load_state, ckpt_files, max_workers=64, - ): - if error is not None: - experiment.warning( - f'Failed to load checkpoint file {ckpt_file}: {error}. ' - 'Skipping the file.' - ) - super().on_experiment_start(experiment) + ) - def on_example_complete( + def _save_example( self, runner: Runner, experiment: Experiment, example: Example, ) -> None: """Saves the example to the checkpoint file.""" - if example.has_error: - experiment.warning( - f'Example {example.id} has error. Skipping checkpointing.' 
+ def save_state(example: Example): + writer = SequenceWriter( + runner.current_run.output_path_for( + experiment, + ( + f'{self._checkpoint_file_prefix}_{example.id}' + f'{self._checkpoint_file_ext}' + ) + ) ) - else: - def save_state(example: Example): - writer = SequenceWriter( - runner.current_run.output_path_for( - experiment, - ( - f'{self._checkpoint_file_prefix}_{example.id}' - f'{self._checkpoint_file_ext}' - ) - ) + try: + writer.add(example) + writer.close() + experiment.info( + f'Example {example.id} saved to {writer.path}.', ) - try: - writer.add(example) - writer.close() - experiment.info( - f'Example {example.id} is saved to {writer.path}.', - ) - except BaseException as e: # pylint: disable=broad-except - experiment.error( - f'Failed to save example {example.id} to {writer.path}. ' - f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', - ) - raise e - runner.background_run(save_state, example) + except BaseException as e: # pylint: disable=broad-except + experiment.error( + f'Failed to save example {example.id} to {writer.path}. ' + f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', + ) + raise e + runner.background_run(save_state, example) def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]: ext_index = filename.rfind('.') @@ -180,30 +240,31 @@ def on_experiment_start( runner: Runner, experiment: Experiment, ) -> None: - """Creates the checkpoint file.""" - if not experiment.is_leaf: - return - # For refresh runs, we don't want to load the previous state. - if not runner.current_run.refresh: - if runner.current_run.input_root != runner.current_run.output_root: - experiment.info( - f'Warm starting from directory: {runner.current_run.input_root}.' - ) - experiment.load_state( - runner.current_run.input_path_for( + super().on_experiment_start(runner, experiment) + + # Prepare the sequence writer for the experiment. 
+ if experiment.is_leaf: + sequence_writer = SequenceWriter( + runner.current_run.output_path_for( experiment, self.checkpoint_filename - ), - raise_if_not_exist=False + ) ) - sequence_writer = SequenceWriter( - runner.current_run.output_path_for( + with self._lock: + if self._sequence_writer is not None: + self._sequence_writer[experiment.id] = sequence_writer + + def _load_experiment( + self, + runner: Runner, + experiment: Experiment, + ) -> None: + """Creates the checkpoint file.""" + experiment.load_state( + runner.current_run.input_path_for( experiment, self.checkpoint_filename - ) + ), + raise_if_not_exist=False ) - with self._lock: - if self._sequence_writer is not None: - self._sequence_writer[experiment.id] = sequence_writer - super().on_experiment_start(experiment) def on_experiment_complete( self, @@ -225,7 +286,7 @@ def on_experiment_complete( f'checkpointed to {writer.path}.' ) - def on_example_complete( + def _save_example( self, runner: Runner, experiment: Experiment, @@ -233,22 +294,20 @@ def on_example_complete( ) -> None: """Saves the example to the checkpoint file.""" assert experiment.id in self._sequence_writer - if example.has_error: - experiment.warning( - f'Example {example.id} has error. Skipping checkpointing.' - ) - else: - def _save_example(example: Example): - writer = self._sequence_writer[experiment.id] - try: - writer.add(example) - except BaseException as e: # pylint: disable=broad-except - experiment.error( - f'Failed to save example {example.id} to {writer.path}. ' - f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', - ) - raise e - runner.background_run(_save_example, example) + def _save_example(example: Example): + writer = self._sequence_writer[experiment.id] + try: + writer.add(example) + experiment.info( + f'Example {example.id} added to {writer.path}.', + ) + except BaseException as e: # pylint: disable=broad-except + experiment.error( + f'Failed to save example {example.id} to {writer.path}. 
' + f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', + ) + raise e + runner.background_run(_save_example, example) class SequenceWriter: diff --git a/langfun/core/eval/v2/checkpointing_test.py b/langfun/core/eval/v2/checkpointing_test.py index b1537a5..725e2a1 100644 --- a/langfun/core/eval/v2/checkpointing_test.py +++ b/langfun/core/eval/v2/checkpointing_test.py @@ -55,6 +55,7 @@ def f(): class PerExampleCheckpointerTest(unittest.TestCase): def test_checkpointing(self): + pg.defaults.loggers.use_stdout() root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer') experiment = eval_test_helper.test_experiment() checkpoint_filename = 'checkpoint.jsonl' diff --git a/langfun/core/eval/v2/experiment.py b/langfun/core/eval/v2/experiment.py index f9b86fe..a2d0a54 100644 --- a/langfun/core/eval/v2/experiment.py +++ b/langfun/core/eval/v2/experiment.py @@ -81,7 +81,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension): directory (using the ID 'latest'). Users can specify 'new' to start a fresh run or provide a specific run ID (typically in the format %Y%m%d_%). Additionally, when initiating a new run, users may specify a `warm_start_from` - ID to restore the experiment’s state from a previous run. + directory to restore the experiment’s state from a previous run. Examples: @@ -97,9 +97,9 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension): # Start a new, clean run. experiment.run(root_dir, 'new') - # Start a new run with a warm start from the previous run located in - # 'run_20241031_1' of the root directory. - experiment.run(root_dir, 'new', warm_start_from='20241031_1') + # Start a new run with a warm start from another run located at + # '/path/to/another/run' (e.g. /my_experiment/run_20241031_1). + experiment.run(root_dir, 'new', warm_start_from='/path/to/another/run') # Resume run '20241031_1', re-running failed examples and recomputing # metrics as needed. 
diff --git a/langfun/core/eval/v2/reporting.py b/langfun/core/eval/v2/reporting.py index ed2c6eb..72c0d4e 100644 --- a/langfun/core/eval/v2/reporting.py +++ b/langfun/core/eval/v2/reporting.py @@ -13,12 +13,14 @@ # limitations under the License. """Reporting evaluation results.""" +import threading import time import traceback from typing import Annotated from langfun.core.eval.v2 import example as example_lib from langfun.core.eval.v2 import experiment as experiment_lib +import pyglove as pg Runner = experiment_lib.Runner Experiment = experiment_lib.Experiment @@ -40,12 +42,15 @@ class HtmlReporter(experiment_lib.Plugin): experiment_report_interval: Annotated[ int, 'The interval of writing report for inidividual experiments in seconds.' - ] = 60 + ] = 120 def _on_bound(self): super()._on_bound() self._last_summary_time = 0 self._last_experiment_report_time = {} + self._update_thread = None + self._stop_update = False + self._stop_update_experiment_ids = set() def on_run_start( self, @@ -54,12 +59,19 @@ def on_run_start( ) -> None: self._maybe_update_summary(runner) self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes} + self._stop_update = False + self._stop_update_experiment_ids = set() + self._update_thread = threading.Thread( + target=self._update_thread_func, args=(runner,) + ) + self._update_thread.start() def on_run_complete( self, runner: Runner, root: Experiment ) -> None: + self._stop_update = True self._maybe_update_summary(runner, force=True) def on_run_abort( @@ -68,8 +80,20 @@ def on_run_abort( root: Experiment, error: BaseException ) -> None: + self._stop_update = True self._maybe_update_summary(runner, force=True) + def _update_thread_func(self, runner: Runner): + while not self._stop_update: + self._maybe_update_summary(runner, background=False) + for leaf in runner.current_run.experiment.leaf_nodes: + if leaf.id in self._stop_update_experiment_ids: + continue + self._maybe_update_experiment_html(runner, leaf, 
background=False) + if leaf.progress.is_stopped: + self._stop_update_experiment_ids.add(leaf.id) + time.sleep(5) + def on_experiment_start( self, runner: Runner, @@ -101,7 +125,11 @@ def on_example_complete( self._maybe_update_experiment_html(runner, experiment) self._maybe_update_summary(runner) - def _maybe_update_summary(self, runner: Runner, force: bool = False) -> None: + def _maybe_update_summary( + self, + runner: Runner, + background: bool = True, + force: bool = False) -> None: """Maybe update the summary of current run.""" run = runner.current_run def _summary(): @@ -115,26 +143,37 @@ def _summary(): ) if force or (time.time() - self._last_summary_time > self.summary_interval): - runner.background_run(_summary) + if background: + runner.background_run(_summary) + else: + _summary() self._last_summary_time = time.time() def _maybe_update_experiment_html( - self, runner: Runner, experiment: Experiment, force: bool = False + self, + runner: Runner, + experiment: Experiment, + force: bool = False, + background: bool = True, ) -> None: def _save(): index_html_path = runner.current_run.output_path_for( experiment, _EVALULATION_DETAIL_FILE ) try: - html = experiment.to_html( - collapse_level=None, - extra_flags=dict( - current_run=runner.current_run, - interactive=False, - card_view=False, - ), - ) - html.save(index_html_path) + with pg.timeit() as t: + html = experiment.to_html( + collapse_level=None, + extra_flags=dict( + current_run=runner.current_run, + interactive=False, + card_view=False, + ), + ) + html.save(index_html_path) + experiment.info( + f'Generated HTML {index_html_path!r} in {t.elapse:.2f} seconds.', + ) except BaseException as e: # pylint: disable=broad-except experiment.error( f'Failed to save HTML {index_html_path!r}. 
' @@ -146,7 +185,10 @@ def _save(): time.time() - self._last_experiment_report_time[experiment.id] > self.experiment_report_interval ): - runner.background_run(_save) + if background: + runner.background_run(_save) + else: + _save() self._last_experiment_report_time[experiment.id] = time.time() def _save_example_html(