Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lf.eval.v2 to support PerExampleCheckpointer. #373

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion langfun/core/eval/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@
from langfun.core.eval.v2 import metrics

from langfun.core.eval.v2.experiment import Plugin

from langfun.core.eval.v2.experiment import Runner
from langfun.core.eval.v2 import runners

# Plugins
from langfun.core.eval.v2.checkpointing import BulkCheckpointer
from langfun.core.eval.v2.checkpointing import PerExampleCheckpointer
from langfun.core.eval.v2.reporting import HtmlReporter


# pylint: enable=g-bad-import-order
# pylint: enable=g-importing-member
112 changes: 96 additions & 16 deletions langfun/core/eval/v2/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Checkpointing evaluation runs."""
import threading

import langfun.core as lf
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib
import pyglove as pg
Expand All @@ -24,21 +25,100 @@


class Checkpointer(experiment_lib.Plugin):
"""Plugin for checkpointing evaluation runs."""
"""Base class for checkpointing evaluation examples."""


class PerExampleCheckpointer(Checkpointer):
"""Checkpointer that saves each example to a separate file."""

checkpoint_filename: str = 'checkpoint.bagz'

def _on_bound(self):
super()._on_bound()
prefix, ext = self._file_prefix_and_ext(self.checkpoint_filename)
self._checkpoint_file_prefix = prefix
self._checkpoint_file_ext = ext

def on_experiment_start(
self,
runner: Runner,
experiment: Experiment,
) -> None:
"""Creates the checkpoint file."""
if not experiment.is_leaf:
return

# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
def _load_state(ckpt_file):
experiment.load_state(ckpt_file)

experiment_dir = runner.current_run.input_dir(experiment)
if pg.io.path_exists(experiment_dir):
ckpt_files = [
runner.current_run.input_path_for(experiment, filename)
for filename in pg.io.listdir(experiment_dir)
if filename.startswith(self._checkpoint_file_prefix)
and filename.endswith(self._checkpoint_file_ext)
]
else:
ckpt_files = []

for ckpt_file, _, error in lf.concurrent_map(
_load_state, ckpt_files, max_workers=64,
):
if error is not None:
pg.logging.warning(
'Failed to load checkpoint file %s: %s. Skipping the file.',
ckpt_file, error
)

def on_example_complete(
self,
runner: Runner,
experiment: Experiment,
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
if not example.has_error:
def save_state(example: Example):
writer = SequenceWriter(
runner.current_run.output_path_for(
experiment,
(
f'{self._checkpoint_file_prefix}_{example.id}'
f'{self._checkpoint_file_ext}'
)
)
)
writer.add(example)
del writer
runner.background_run(save_state, example)

def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
ext_index = filename.rfind('.')
if ext_index == -1:
return filename, ''
else:
return filename[:ext_index], filename[ext_index:]


class BulkCheckpointer(Checkpointer):
"""Checkpointer that saves all examples to a single file."""

checkpoint_filename: str = 'checkpoint.bagz'

def _on_bound(self):
super()._on_bound()
self._lock = threading.Lock()
self._state_writer = None
self._sequence_writer = None

def on_run_start(
self,
runner: Runner,
root: Experiment,
) -> None:
self._state_writer = {}
self._sequence_writer = {}

def on_run_abort(
self,
Expand All @@ -47,16 +127,16 @@ def on_run_abort(
error: BaseException
) -> None:
with self._lock:
if self._state_writer is not None:
self._state_writer.clear()
if self._sequence_writer is not None:
self._sequence_writer.clear()

def on_run_complete(
self,
runner: Runner,
root: Experiment,
) -> None:
with self._lock:
assert self._state_writer is not None and not self._state_writer
assert self._sequence_writer is not None and not self._sequence_writer

def on_experiment_start(
self,
Expand All @@ -74,14 +154,14 @@ def on_experiment_start(
),
raise_if_not_exist=False
)
state_writer = StateWriter(
sequence_writer = SequenceWriter(
runner.current_run.output_path_for(
experiment, self.checkpoint_filename
)
)
with self._lock:
if self._state_writer is not None:
self._state_writer[experiment.id] = state_writer
if self._sequence_writer is not None:
self._sequence_writer[experiment.id] = sequence_writer

def on_experiment_complete(
self,
Expand All @@ -91,10 +171,10 @@ def on_experiment_complete(
"""Closes the checkpoint file."""
if not experiment.is_leaf:
return
assert experiment.id in self._state_writer
assert experiment.id in self._sequence_writer
with self._lock:
if self._state_writer is not None:
del self._state_writer[experiment.id]
if self._sequence_writer is not None:
del self._sequence_writer[experiment.id]

def on_example_complete(
self,
Expand All @@ -103,13 +183,13 @@ def on_example_complete(
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
assert experiment.id in self._state_writer
assert experiment.id in self._sequence_writer
if not example.has_error:
runner.background_run(self._state_writer[experiment.id].add, example)
runner.background_run(self._sequence_writer[experiment.id].add, example)


class StateWriter:
"""Thread safe state writer."""
class SequenceWriter:
"""Thread safe sequence writer."""

def __init__(self, path: str):
self._lock = threading.Lock()
Expand Down
49 changes: 41 additions & 8 deletions langfun/core/eval/v2/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@
Example = example_lib.Example


class StateWriterTest(unittest.TestCase):
class SequenceWriterTest(unittest.TestCase):

def test_basic(self):
file = os.path.join(tempfile.gettempdir(), 'test.jsonl')
writer = checkpointing.StateWriter(file)
writer = checkpointing.SequenceWriter(file)
example = Example(id=1, input=pg.Dict(x=1), output=2)
writer.add(example)
del writer
self.assertTrue(pg.io.path_exists(file))

def test_error_handling(self):
file = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
writer = checkpointing.StateWriter(file)
writer = checkpointing.SequenceWriter(file)
writer.add(Example(id=1, input=pg.Dict(x=1), output=2))

def f():
Expand All @@ -52,17 +52,50 @@ def f():
self.assertEqual(len(list(iter(f))), 1)


class CheckpointingTest(unittest.TestCase):
class PerExampleCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_checkpointing')
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
experiment = test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.Checkpointer(checkpoint_filename)
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
run = experiment.run(
root_dir, 'new', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._state_writer), 0)
num_processed = {}
for leaf in experiment.leaf_nodes:
for i in range(leaf.num_examples):
example = leaf.state.get(i + 1)
ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
if example.has_error:
self.assertFalse(pg.io.path_exists(ckpt))
else:
self.assertTrue(pg.io.path_exists(ckpt))
with pg.io.open_sequence(ckpt) as f:
self.assertEqual(len(list(iter(f))), 1)
if leaf.id not in num_processed:
self.assertEqual(leaf.progress.num_skipped, 0)
num_processed[leaf.id] = leaf.progress.num_processed

# Run again, should skip existing.
_ = experiment.run(
root_dir, 'latest', runner='sequential', plugins=[checkpointer]
)
for leaf in experiment.leaf_nodes:
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])


class BulkCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
experiment = test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
run = experiment.run(
root_dir, 'new', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._sequence_writer), 0)
num_processed = {}
for leaf in experiment.leaf_nodes:
ckpt = run.output_path_for(leaf, checkpoint_filename)
Expand All @@ -80,7 +113,7 @@ def test_checkpointing(self):
_ = experiment.run(
root_dir, 'latest', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._state_writer), 0)
self.assertEqual(len(checkpointer._sequence_writer), 0)
for leaf in experiment.leaf_nodes:
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])

Expand Down
2 changes: 1 addition & 1 deletion langfun/core/eval/v2/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class RunnerBase(Runner):
] = False

plugins = [
checkpointing.Checkpointer(),
checkpointing.BulkCheckpointer(),
reporting.HtmlReporter(),
]

Expand Down
Loading