Skip to content

Commit

Permalink
Enhancements to lf.eval.v2
Browse files Browse the repository at this point in the history
1) Enable periodical update of essential HTML files: summary.html will be updated every 1 minute, experiment index.html will be updated every 2 minutes.

2) Enrich logging for checkpoint loading and writing.

PiperOrigin-RevId: 708517113
  • Loading branch information
daiyip authored and langfun authors committed Dec 21, 2024
1 parent a5d2bbc commit 9b8f7ec
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 120 deletions.
263 changes: 161 additions & 102 deletions langfun/core/eval/v2/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpointing evaluation runs."""
import abc
import threading
import traceback

Expand All @@ -28,21 +29,74 @@
class Checkpointer(experiment_lib.Plugin):
"""Base class for checkpointing evaluation examples."""

def on_experiment_start(self, experiment: Experiment):
def on_experiment_start(
self,
runner: Runner,
experiment: Experiment
) -> None:
if not experiment.is_leaf:
return

# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
if runner.current_run.input_root != runner.current_run.output_root:
experiment.info(
f'Warm starting from directory: {runner.current_run.input_root}.'
)
self._load_experiment(runner, experiment)

if experiment.state.evaluated_examples:
loaded_example_ids = list(
sorted(experiment.state.evaluated_examples.keys())
)
example_ids_to_evaluate = (
set(runner.current_run.example_ids) if runner.current_run.example_ids
else set(range(1, experiment.num_examples + 1))
)
example_ids_to_evaluate -= set(loaded_example_ids)

experiment.info(
'Loaded %d examples from checkpoint files. Example IDs: %s' %
(
len(experiment.state.evaluated_examples),
list(sorted(experiment.state.evaluated_examples.keys()))
),
f'{len(experiment.state.evaluated_examples)} examples have been '
'loaded from checkpoint files. Their outputs will be used '
f'for recomputing metrics. Example IDs: {loaded_example_ids}'
)
experiment.info(
f'{len(example_ids_to_evaluate)} examples will be processed from '
f'scratch. Example IDs: {list(sorted(example_ids_to_evaluate))}'
)
else:
experiment.info(
'No previous evaluated examples are loaded. '
'No examples are loaded from checkpoint files. '
f'Experiment {experiment.id} starts from scratch.'
)

def on_example_complete(
self,
runner: Runner,
experiment: Experiment,
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
if example.has_error:
experiment.warning(
f'Example {example.id} has error. Skipping checkpointing.'
)
else:
self._save_example(runner, experiment, example)

@abc.abstractmethod
def _load_experiment(self, runner: Runner, experiment: Experiment) -> None:
"""Loads the experiment state from checkpoint files."""

@abc.abstractmethod
def _save_example(
self,
runner: Runner,
experiment: Experiment,
example: Example,
) -> None:
"""Saves an evaluated example."""


class PerExampleCheckpointer(Checkpointer):
"""Checkpointer that saves each example to a separate file."""
Expand All @@ -55,80 +109,86 @@ def _on_bound(self):
self._checkpoint_file_prefix = prefix
self._checkpoint_file_ext = ext

def on_experiment_start(
def _load_experiment(
self,
runner: Runner,
experiment: Experiment,
) -> None:
"""Creates the checkpoint file."""
if not experiment.is_leaf:
return
experiment_dir = runner.current_run.input_dir(experiment)
if pg.io.path_exists(experiment_dir):
ckpt_files = [
runner.current_run.input_path_for(experiment, filename)
for filename in pg.io.listdir(experiment_dir)
if filename.startswith(self._checkpoint_file_prefix)
and filename.endswith(self._checkpoint_file_ext)
]
else:
ckpt_files = []

# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
if runner.current_run.input_root != runner.current_run.output_root:
experiment.info(
f'Warm starting from directory: {runner.current_run.input_root}.'
experiment.info(f'Found {len(ckpt_files)} checkpoint files to load.')

# Load the checkpoint files in parallel.
context = dict(counter=0, counter_lock=threading.Lock())
def _load_state(ckpt_file):
error = None
with pg.timeit() as t:
try:
experiment.load_state(ckpt_file)
except BaseException as e: # pylint: disable=broad-except
error = e
finally:
with context['counter_lock']:
context['counter'] += 1

progress_str = f'{context["counter"]}/{len(ckpt_files)}'
if error is None:
experiment.info(
f'Loaded checkpoint file {ckpt_file} in {t.elapse:.2f} '
f'seconds. ({progress_str})'
)
else:
experiment.warning(
f'Failed to load checkpoint file {ckpt_file}: {error}. '
f'Skipping the file. ({progress_str})'
)

_ = list(
lf.concurrent_map(
_load_state, ckpt_files, max_workers=16, silence_on_errors=None
)
def _load_state(ckpt_file):
experiment.load_state(ckpt_file)

experiment_dir = runner.current_run.input_dir(experiment)
if pg.io.path_exists(experiment_dir):
ckpt_files = [
runner.current_run.input_path_for(experiment, filename)
for filename in pg.io.listdir(experiment_dir)
if filename.startswith(self._checkpoint_file_prefix)
and filename.endswith(self._checkpoint_file_ext)
]
else:
ckpt_files = []

for ckpt_file, _, error in lf.concurrent_map(
_load_state, ckpt_files, max_workers=64,
):
if error is not None:
experiment.warning(
f'Failed to load checkpoint file {ckpt_file}: {error}. '
'Skipping the file.'
)
super().on_experiment_start(experiment)
)

def on_example_complete(
def _save_example(
self,
runner: Runner,
experiment: Experiment,
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
if example.has_error:
experiment.warning(
f'Example {example.id} has error. Skipping checkpointing.'
def save_state(example: Example):
writer = SequenceWriter(
runner.current_run.output_path_for(
experiment,
(
f'{self._checkpoint_file_prefix}_{example.id}'
f'{self._checkpoint_file_ext}'
)
)
)
else:
def save_state(example: Example):
writer = SequenceWriter(
runner.current_run.output_path_for(
experiment,
(
f'{self._checkpoint_file_prefix}_{example.id}'
f'{self._checkpoint_file_ext}'
)
)
try:
writer.add(example)
writer.close()
experiment.info(
f'Example {example.id} saved to {writer.path}.',
)
try:
writer.add(example)
writer.close()
experiment.info(
f'Example {example.id} is saved to {writer.path}.',
)
except BaseException as e: # pylint: disable=broad-except
experiment.error(
f'Failed to save example {example.id} to {writer.path}. '
f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
)
raise e
runner.background_run(save_state, example)
except BaseException as e: # pylint: disable=broad-except
experiment.error(
f'Failed to save example {example.id} to {writer.path}. '
f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
)
raise e
runner.background_run(save_state, example)

def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
ext_index = filename.rfind('.')
Expand Down Expand Up @@ -180,30 +240,31 @@ def on_experiment_start(
runner: Runner,
experiment: Experiment,
) -> None:
"""Creates the checkpoint file."""
if not experiment.is_leaf:
return
# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
if runner.current_run.input_root != runner.current_run.output_root:
experiment.info(
f'Warm starting from directory: {runner.current_run.input_root}.'
)
experiment.load_state(
runner.current_run.input_path_for(
super().on_experiment_start(runner, experiment)

# Prepare the sequence writer for the experiment.
if experiment.is_leaf:
sequence_writer = SequenceWriter(
runner.current_run.output_path_for(
experiment, self.checkpoint_filename
),
raise_if_not_exist=False
)
)
sequence_writer = SequenceWriter(
runner.current_run.output_path_for(
with self._lock:
if self._sequence_writer is not None:
self._sequence_writer[experiment.id] = sequence_writer

def _load_experiment(
self,
runner: Runner,
experiment: Experiment,
) -> None:
"""Creates the checkpoint file."""
experiment.load_state(
runner.current_run.input_path_for(
experiment, self.checkpoint_filename
)
),
raise_if_not_exist=False
)
with self._lock:
if self._sequence_writer is not None:
self._sequence_writer[experiment.id] = sequence_writer
super().on_experiment_start(experiment)

def on_experiment_complete(
self,
Expand All @@ -225,30 +286,28 @@ def on_experiment_complete(
f'checkpointed to {writer.path}.'
)

def on_example_complete(
def _save_example(
self,
runner: Runner,
experiment: Experiment,
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
assert experiment.id in self._sequence_writer
if example.has_error:
experiment.warning(
f'Example {example.id} has error. Skipping checkpointing.'
)
else:
def _save_example(example: Example):
writer = self._sequence_writer[experiment.id]
try:
writer.add(example)
except BaseException as e: # pylint: disable=broad-except
experiment.error(
f'Failed to save example {example.id} to {writer.path}. '
f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
)
raise e
runner.background_run(_save_example, example)
def _save_example(example: Example):
writer = self._sequence_writer[experiment.id]
try:
writer.add(example)
experiment.info(
f'Example {example.id} added to {writer.path}.',
)
except BaseException as e: # pylint: disable=broad-except
experiment.error(
f'Failed to save example {example.id} to {writer.path}. '
f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
)
raise e
runner.background_run(_save_example, example)


class SequenceWriter:
Expand Down
1 change: 1 addition & 0 deletions langfun/core/eval/v2/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def f():
class PerExampleCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
pg.defaults.loggers.use_stdout()
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
experiment = eval_test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
Expand Down
8 changes: 4 additions & 4 deletions langfun/core/eval/v2/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
directory (using the ID 'latest'). Users can specify 'new' to start a fresh
run or provide a specific run ID (typically in the format %Y%m%d_%<number>).
Additionally, when initiating a new run, users may specify a `warm_start_from`
ID to restore the experiment’s state from a previous run.
directory to restore the experiment’s state from a previous run.
Examples:
Expand All @@ -97,9 +97,9 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
# Start a new, clean run.
experiment.run(root_dir, 'new')
# Start a new run with a warm start from the previous run located in
# 'run_20241031_1' of the root directory.
experiment.run(root_dir, 'new', warm_start_from='20241031_1')
# Start a new run with a warm start from the another run located at
# '/path/to/another/run' (e.g. /my_expreriment/run_20241031_1).
experiment.run(root_dir, 'new', warm_start_from='/path/to/another/run')
# Resume run '20241031_1', re-running failed examples and recomputing
# metrics as needed.
Expand Down
Loading

0 comments on commit 9b8f7ec

Please sign in to comment.