diff --git a/langfun/core/eval/v2/checkpointing.py b/langfun/core/eval/v2/checkpointing.py index d09b000..1b18419 100644 --- a/langfun/core/eval/v2/checkpointing.py +++ b/langfun/core/eval/v2/checkpointing.py @@ -13,6 +13,7 @@ # limitations under the License. """Checkpointing evaluation runs.""" import threading +import traceback import langfun.core as lf from langfun.core.eval.v2 import example as example_lib @@ -27,6 +28,21 @@ class Checkpointer(experiment_lib.Plugin): """Base class for checkpointing evaluation examples.""" + def on_experiment_start(self, experiment: Experiment): + if experiment.state.evaluated_examples: + experiment.info( + 'Loaded %d examples from checkpoint files. Example IDs: %s' % + ( + len(experiment.state.evaluated_examples), + list(sorted(experiment.state.evaluated_examples.keys())) + ), + ) + else: + experiment.info( + 'No previous evaluated examples are loaded. ' + f'Experiment {experiment.id} starts from scratch.' + ) + class PerExampleCheckpointer(Checkpointer): """Checkpointer that saves each example to a separate file.""" @@ -68,10 +84,11 @@ def _load_state(ckpt_file): _load_state, ckpt_files, max_workers=64, ): if error is not None: - pg.logging.warning( + experiment.warning( 'Failed to load checkpoint file %s: %s. Skipping the file.', ckpt_file, error ) + super().on_experiment_start(experiment) def on_example_complete( self, @@ -80,7 +97,11 @@ def on_example_complete( example: Example, ) -> None: """Saves the example to the checkpoint file.""" - if not example.has_error: + if example.has_error: + experiment.warning( + f'Example {example.id} has error. Skipping checkpointing.' + ) + else: def save_state(example: Example): writer = SequenceWriter( runner.current_run.output_path_for( @@ -91,8 +112,18 @@ def save_state(example: Example): ) ) ) - writer.add(example) - writer.close() + try: + writer.add(example) + writer.close() + experiment.info( + f'Example {example.id} is saved to {writer.path}.', + ) + except BaseException as e: # pylint: disable=broad-except + experiment.error( + f'Failed to save example {example.id} to {writer.path}. ' + f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', + ) + raise e runner.background_run(save_state, example) def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]: @@ -164,6 +195,7 @@ def on_experiment_start( with self._lock: if self._sequence_writer is not None: self._sequence_writer[experiment.id] = sequence_writer + super().on_experiment_start(experiment) def on_experiment_complete( self, @@ -178,8 +210,12 @@ def on_experiment_complete( if self._sequence_writer is not None: # Make sure the writer is closed without delay so the file will be # available immediately. - self._sequence_writer[experiment.id].close() - del self._sequence_writer[experiment.id] + writer = self._sequence_writer.pop(experiment.id) + writer.close() + experiment.info( + f'{len(experiment.state.evaluated_examples)} examples are ' + f'checkpointed to {writer.path}.' + ) def on_example_complete( self, @@ -189,8 +225,22 @@ def on_example_complete( ) -> None: """Saves the example to the checkpoint file.""" assert experiment.id in self._sequence_writer - if not example.has_error: - runner.background_run(self._sequence_writer[experiment.id].add, example) + if example.has_error: + experiment.warning( + f'Example {example.id} has error. Skipping checkpointing.' + ) + else: + def _save_example(example: Example): + writer = self._sequence_writer[experiment.id] + try: + writer.add(example) + except BaseException as e: # pylint: disable=broad-except + experiment.error( + f'Failed to save example {example.id} to {writer.path}. ' + f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', + ) + raise e + runner.background_run(_save_example, example) class SequenceWriter: @@ -198,8 +248,13 @@ class SequenceWriter: def __init__(self, path: str): self._lock = threading.Lock() + self._path = path self._sequence_writer = pg.io.open_sequence(path, 'w') + @property + def path(self) -> str: + return self._path + def add(self, example: Example): example_blob = pg.to_json_str( example, diff --git a/langfun/core/eval/v2/evaluation.py b/langfun/core/eval/v2/evaluation.py index 9a1887d..a699d7a 100644 --- a/langfun/core/eval/v2/evaluation.py +++ b/langfun/core/eval/v2/evaluation.py @@ -285,36 +285,43 @@ def _reset(self) -> None: # Evaluation-level logging. # - def _log(self, level: lf.logging.LogLevel, message: str, **kwargs): + def _log(self, log_func, level: lf.logging.LogLevel, message: str, **kwargs): + # Write to external logging system. + log_message = f'{self.id}: {message}' + if kwargs: + log_message = f'{log_message} (metadata: {kwargs!r})' + log_func(log_message) + + # Add to experiment log history. + log_entry = lf.logging.LogEntry( + level=level, + time=datetime.datetime.now(), + message=message, + metadata=kwargs, + ) with self._log_lock: - self._log_entries.append( - lf.logging.LogEntry( - level=level, - time=datetime.datetime.now(), - message=message, - metadata=kwargs, - ) - ) + self._log_entries.append(log_entry) def debug(self, message: str, **kwargs): """Logs a debug message to the session.""" - self._log('debug', message, **kwargs) + self._log(pg.logging.debug, 'debug', message, **kwargs) def info(self, message: str, **kwargs): """Logs an info message to the session.""" - self._log('info', message, **kwargs) + self._log(pg.logging.info, 'info', message, **kwargs) def warning(self, message: str, **kwargs): """Logs a warning message to the session.""" - self._log('warning', message, **kwargs) + self._log(pg.logging.warning, 'warning', message, **kwargs) def error(self, message: str, **kwargs): """Logs an error message to the session.""" - self._log('error', message, **kwargs) + self._log(pg.logging.error, 'error', message, **kwargs) def fatal(self, message: str, **kwargs): """Logs a fatal message to the session.""" - self._log('fatal', message, **kwargs) + # We use error level for fatal message, which does not trigger assertion. + self._log(pg.logging.error, 'fatal', message, **kwargs) # # HTML views. diff --git a/langfun/core/eval/v2/experiment.py b/langfun/core/eval/v2/experiment.py index 4f18917..f9b86fe 100644 --- a/langfun/core/eval/v2/experiment.py +++ b/langfun/core/eval/v2/experiment.py @@ -959,6 +959,14 @@ def on_experiment_complete( ) -> None: """Called when an experiment (both leaf and non-leaf) is complete.""" + def on_experiment_abort( + self, + runner: Runner, + experiment: Experiment, + error: BaseException, + ) -> None: + """Called when an experiment (both leaf and non-leaf) is aborted.""" + def on_example_start( self, runner: Runner, diff --git a/langfun/core/eval/v2/reporting.py b/langfun/core/eval/v2/reporting.py index 64fcd2e..ed2c6eb 100644 --- a/langfun/core/eval/v2/reporting.py +++ b/langfun/core/eval/v2/reporting.py @@ -14,6 +14,7 @@ """Reporting evaluation results.""" import time +import traceback from typing import Annotated from langfun.core.eval.v2 import example as example_lib @@ -61,6 +62,14 @@ def on_run_complete( ) -> None: self._maybe_update_summary(runner, force=True) + def on_run_abort( + self, + runner: Runner, + root: Experiment, + error: BaseException + ) -> None: + self._maybe_update_summary(runner, force=True) + def on_experiment_start( self, runner: Runner, @@ -75,6 +84,16 @@ def on_experiment_complete( if experiment.is_leaf: self._maybe_update_experiment_html(runner, experiment, force=True) + def on_experiment_abort( + self, + runner: Runner, + experiment: Experiment, + error: BaseException + ) -> None: + del error + assert experiment.is_leaf + self._maybe_update_experiment_html(runner, experiment, force=True) + def on_example_complete( self, runner: Runner, experiment: Experiment, example: Example ): @@ -103,19 +122,26 @@ def _maybe_update_experiment_html( self, runner: Runner, experiment: Experiment, force: bool = False ) -> None: def _save(): - html = experiment.to_html( - collapse_level=None, - extra_flags=dict( - current_run=runner.current_run, - interactive=False, - card_view=False, - ), - ) - html.save( - runner.current_run.output_path_for( - experiment, _EVALULATION_DETAIL_FILE - ) + index_html_path = runner.current_run.output_path_for( + experiment, _EVALULATION_DETAIL_FILE ) + try: + html = experiment.to_html( + collapse_level=None, + extra_flags=dict( + current_run=runner.current_run, + interactive=False, + card_view=False, + ), + ) + html.save(index_html_path) + except BaseException as e: # pylint: disable=broad-except + experiment.error( + f'Failed to save HTML {index_html_path!r}. ' + f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', + ) + raise e + if force or ( time.time() - self._last_experiment_report_time[experiment.id] > self.experiment_report_interval @@ -128,17 +154,24 @@ def _save_example_html( ) -> None: """Saves the example.""" def _save(): - html = example.to_html( - collapse_level=None, - enable_summary_tooltip=False, - extra_flags=dict( - # For properly rendering the next link. - num_examples=getattr(experiment, 'num_examples', None) - ), - ) - html.save( - runner.current_run.output_path_for( - experiment, f'{example.id}.html' - ) - ) + try: + html = example.to_html( + collapse_level=None, + enable_summary_tooltip=False, + extra_flags=dict( + # For properly rendering the next link. + num_examples=getattr(experiment, 'num_examples', None) + ), + ) + html.save( + runner.current_run.output_path_for( + experiment, f'{example.id}.html' + ) + ) + except BaseException as e: # pylint: disable=broad-except + experiment.error( + f'Failed to save HTML {example.id}.html. ' + f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.', + ) + raise e runner.background_run(_save) diff --git a/langfun/core/eval/v2/runners.py b/langfun/core/eval/v2/runners.py index 6bc4656..caeae77 100644 --- a/langfun/core/eval/v2/runners.py +++ b/langfun/core/eval/v2/runners.py @@ -18,6 +18,7 @@ import random import threading import time +import traceback from typing import Any, Annotated, Callable, Iterator from langfun import core as lf @@ -120,9 +121,14 @@ def on_experiment_start(self, experiment: Experiment) -> None: # Start the progress of the evaluation. if experiment.is_leaf: assert isinstance(experiment, Evaluation) - experiment.progress.start( - total=(len(self.current_run.example_ids) - if self.current_run.example_ids else experiment.num_examples) + num_examples_to_evaluate = ( + len(self.current_run.example_ids) + if self.current_run.example_ids else experiment.num_examples + ) + experiment.progress.start(total=num_examples_to_evaluate) + experiment.info( + 'Starting evaluation %s with %d examples to evaluate.' + % (experiment.id, num_examples_to_evaluate) ) else: experiment.progress.start(total=len(experiment.leaf_nodes)) @@ -144,8 +150,7 @@ def on_experiment_skipped(self, experiment: Experiment) -> None: # Only leaf evaluations will trigger the complete notification of the # ancestors. - if experiment.is_leaf: - self._update_ancestor_progresses(experiment) + self._update_ancestor_progresses(experiment) def on_experiment_complete(self, experiment: Experiment) -> None: """Called when an evaluation is complete.""" @@ -160,6 +165,35 @@ def on_experiment_complete(self, experiment: Experiment) -> None: # ancestors. if experiment.is_leaf: self._update_ancestor_progresses(experiment) + self._log_experiment_completion(experiment) + + def _log_experiment_completion(self, experiment: Experiment): + example_ids = ( + self.current_run.example_ids if self.current_run.example_ids else + list(range(1, experiment.num_examples + 1)) + ) + num_from_checkpoint, num_processed = 0, 0 + for example_id in example_ids: + example = experiment.state.get(example_id) + if example.newly_processed: + num_processed += 1 + else: + num_from_checkpoint += 1 + experiment.info( + f'{experiment.id} completed with {num_from_checkpoint + num_processed} ' + f'examples evaluated ({num_from_checkpoint} from checkpoint, ' + f'{num_processed} newly processed).' + ) + + def on_experiment_abort( + self, experiment: Experiment, error: BaseException) -> None: + """Called when an evaluation is complete.""" + assert experiment.is_leaf + experiment.fatal(f'{error}\n\n{traceback.format_exc()}') + + # Notify the plugins of the experiment abort. + for plugin in self._all_plugins(experiment): + plugin.on_experiment_abort(self, experiment, error) def _update_ancestor_progresses(self, experiment: Experiment): """Updates the progresses of the parent nodes of the experiment.""" @@ -270,31 +304,36 @@ def _run(self, evaluations: list[Evaluation]) -> None: def run_evaluation(self, evaluation: Evaluation) -> None: """Runs the evaluation.""" - self.on_experiment_start(evaluation) - - per_evaluation_settings = {} - cache = None - if self.current_run.use_cache == 'per_dataset': - cache = self._load_or_create_cache(evaluation) - per_evaluation_settings['cache'] = cache - - with lf.use_settings(**per_evaluation_settings): - if self.current_run.example_ids is None: - items = ( - Example(id=i + 1, input=ex) for i, ex in enumerate( - evaluation.example_inputs) - ) - else: - items = ( - Example( - id=example_id, input=evaluation.example_input_by_id(example_id) - ) for example_id in self.current_run.example_ids - ) - self._evaluate_items(evaluation, items) - - if cache: - self.background_run(cache.save) - self.on_experiment_complete(evaluation) + try: + self.on_experiment_start(evaluation) + + per_evaluation_settings = {} + cache = None + if self.current_run.use_cache == 'per_dataset': + cache = self._load_or_create_cache(evaluation) + per_evaluation_settings['cache'] = cache + + with lf.use_settings(**per_evaluation_settings): + if self.current_run.example_ids is None: + items = ( + Example(id=i + 1, input=ex) for i, ex in enumerate( + evaluation.example_inputs) + ) + else: + items = ( + Example( + id=example_id, + input=evaluation.example_input_by_id(example_id) + ) for example_id in self.current_run.example_ids + ) + self._evaluate_items(evaluation, items) + + if cache: + self.background_run(cache.save) + self.on_experiment_complete(evaluation) + except BaseException as e: # pylint: disable=broad-except + self.on_experiment_abort(evaluation, e) + raise e @abc.abstractmethod def _evaluate_items( @@ -410,9 +449,7 @@ def _run_group(evaluation_group: list[Evaluation]): groups.values(), max_workers=max(64, len(groups)), timeout=self.timeout, - silence_on_errors=( - None if self.current_run.raise_if_has_error else BaseException - ) + silence_on_errors=None, ): pass @@ -437,8 +474,6 @@ def _evaluate_item(item: Example): items, max_workers=evaluation.max_workers, timeout=self.timeout, - silence_on_errors=( - None if self.current_run.raise_if_has_error else BaseException - ) + silence_on_errors=None, ): pass