Skip to content

Commit

Permalink
Refining lf.eval.v2.
Browse files Browse the repository at this point in the history
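- Rename `eval/v2/test_helper.py` to `eval/v2/eval_test_helper.py` so that pytest no longer collects it as a test module (pytest collects files matching `test_*.py` or `*_test.py`; the new name matches neither pattern).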
- ParallelRunner: Avoid a race condition on `io_pool` between task submission and pool shutdown.

PiperOrigin-RevId: 708369332
  • Loading branch information
daiyip authored and langfun authors committed Dec 20, 2024
1 parent 74ee960 commit a5d2bbc
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 36 deletions.
12 changes: 10 additions & 2 deletions langfun/core/eval/v2/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def on_experiment_start(

# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
if runner.current_run.input_root != runner.current_run.output_root:
experiment.info(
f'Warm starting from directory: {runner.current_run.input_root}.'
)
def _load_state(ckpt_file):
experiment.load_state(ckpt_file)

Expand All @@ -85,8 +89,8 @@ def _load_state(ckpt_file):
):
if error is not None:
experiment.warning(
'Failed to load checkpoint file %s: %s. Skipping the file.',
ckpt_file, error
f'Failed to load checkpoint file {ckpt_file}: {error}. '
'Skipping the file.'
)
super().on_experiment_start(experiment)

Expand Down Expand Up @@ -181,6 +185,10 @@ def on_experiment_start(
return
# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
if runner.current_run.input_root != runner.current_run.output_root:
experiment.info(
f'Warm starting from directory: {runner.current_run.input_root}.'
)
experiment.load_state(
runner.current_run.input_path_for(
experiment, self.checkpoint_filename
Expand Down
6 changes: 3 additions & 3 deletions langfun/core/eval/v2/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import unittest

from langfun.core.eval.v2 import checkpointing
from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper
import pyglove as pg

Example = example_lib.Example
Expand Down Expand Up @@ -56,7 +56,7 @@ class PerExampleCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
run = experiment.run(
Expand Down Expand Up @@ -89,7 +89,7 @@ class BulkCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
run = experiment.run(
Expand Down
File renamed without changes.
35 changes: 21 additions & 14 deletions langfun/core/eval/v2/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
import tempfile
import unittest

from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import evaluation as evaluation_lib
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib

from langfun.core.eval.v2 import test_helper

import pyglove as pg

Example = example_lib.Example
Expand All @@ -32,17 +31,23 @@
class EvaluationTest(unittest.TestCase):

def test_hyper_evaluation(self):
exp = test_helper.TestEvaluation(
lm=test_helper.TestLLM(offset=pg.oneof(range(3)))
exp = eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=pg.oneof(range(3)))
)
self.assertFalse(exp.is_leaf)
self.assertTrue(
pg.eq(
exp.children,
[
test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=0)),
test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1)),
test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=2)),
eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=0)
),
eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=1)
),
eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=2)
),
]
)
)
Expand All @@ -57,19 +62,21 @@ def test_hyper_evaluation(self):
)

def test_input(self):
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
self.assertEqual(exp.num_examples, 10)
exp = test_helper.TestEvaluation(inputs=test_helper.test_inputs(None))
exp = eval_test_helper.TestEvaluation(
inputs=eval_test_helper.test_inputs(None)
)
self.assertEqual(exp.num_examples, 20)
@pg.functor
def my_inputs():
yield pg.Dict(x=1, y=2)
yield pg.Dict(x=3, y=4)
exp = test_helper.TestEvaluation(inputs=my_inputs())
exp = eval_test_helper.TestEvaluation(inputs=my_inputs())
self.assertEqual(exp.num_examples, 2)

def test_evaluate(self):
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
example = exp.evaluate(Example(id=3))
self.assertIs(exp.state.get(3), example)
self.assertTrue(example.newly_processed)
Expand All @@ -85,7 +92,7 @@ def test_evaluate(self):
self.assertIsNotNone(example.start_time)
self.assertIsNotNone(example.end_time)

exp = test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1))
exp = eval_test_helper.TestEvaluation(lm=eval_test_helper.TestLLM(offset=1))
example = exp.evaluate(3)
self.assertTrue(example.newly_processed)
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
Expand All @@ -109,7 +116,7 @@ def test_evaluate_with_state(self):
pg.io.mkdirs(eval_dir, exist_ok=True)
state_file = os.path.join(eval_dir, 'state.jsonl')
with pg.io.open_sequence(state_file, 'w') as f:
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
example = exp.evaluate(3)
self.assertTrue(example.newly_processed)
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
Expand All @@ -132,7 +139,7 @@ def test_evaluate_with_state(self):
self.assertEqual(example.usage_summary.uncached.total.num_requests, 0)

def test_html_view(self):
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
exp.debug('debug message')
exp.info('info message')
exp.warning('warning message', x=1)
Expand Down
8 changes: 4 additions & 4 deletions langfun/core/eval/v2/progress_tracking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import unittest

from langfun.core import console as lf_console
from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import progress_tracking # pylint: disable=unused-import
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper
import pyglove as pg


Expand All @@ -35,7 +35,7 @@ def display(x):
display=display
)
root_dir = os.path.join(tempfile.gettempdir(), 'test_html_progress_tracker')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
_ = experiment.run(root_dir, 'new', plugins=[])
self.assertIsInstance(result['view'], pg.Html)
lf_console._notebook = None
Expand All @@ -45,7 +45,7 @@ class TqdmProgressTrackerTest(unittest.TestCase):

def test_basic(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_tqdm_progress_tracker')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
_ = experiment.run(root_dir, 'new', plugins=[])
Expand All @@ -55,7 +55,7 @@ def test_with_example_ids(self):
root_dir = os.path.join(
tempfile.gettempdir(), 'test_tqdm_progress_tracker_with_example_ids'
)
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
_ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
Expand Down
4 changes: 2 additions & 2 deletions langfun/core/eval/v2/reporting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
import tempfile
import unittest

from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import reporting
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper
import pyglove as pg


class ReportingTest(unittest.TestCase):

def test_reporting(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
reporter = reporting.HtmlReporter()
run = experiment.run(root_dir, 'new', plugins=[reporter])
pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
Expand Down
10 changes: 8 additions & 2 deletions langfun/core/eval/v2/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _on_bound(self):
with pg.notify_on_change(False):
self.plugins.append(progress_tracking.progress_tracker(self.tqdm))

self._io_pool_lock = threading.Lock()
self._io_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16)
# TODO(daiyip): render background errors.
self._background_last_error = None
Expand All @@ -76,7 +77,10 @@ def _background_run(*args, **kwargs):
func(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
self._background_last_error = e
self._io_pool.submit(_background_run, *args, **kwargs)

with self._io_pool_lock:
if self._io_pool is not None:
self._io_pool.submit(_background_run, *args, **kwargs)

def _all_plugins(self, experiment: Experiment) -> Iterator[Plugin]:
"""Returns all plugins for the experiment."""
Expand Down Expand Up @@ -296,7 +300,9 @@ def run(self) -> None:
self.background_run(cache.save)

# Wait for the background tasks to finish.
self._io_pool.shutdown(wait=True)
with self._io_pool_lock:
self._io_pool, io_pool = None, self._io_pool
io_pool.shutdown(wait=True)

@abc.abstractmethod
def _run(self, evaluations: list[Evaluation]) -> None:
Expand Down
19 changes: 10 additions & 9 deletions langfun/core/eval/v2/runners_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from typing import Any
import unittest

from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper

import pyglove as pg


Expand Down Expand Up @@ -101,7 +102,7 @@ def assert_same_list(self, actual: list[Any], expected: list[Any]):

def test_basic(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
run = exp.run(root_dir, runner='sequential', plugins=[plugin])

Expand Down Expand Up @@ -143,7 +144,7 @@ def test_basic(self):

def test_raise_if_has_error(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
with self.assertRaisesRegex(ValueError, 'x should not be 5'):
exp.run(
root_dir, runner='sequential', plugins=[], raise_if_has_error=True
Expand All @@ -154,7 +155,7 @@ def test_raise_if_has_error(self):

def test_example_ids(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
plugin = TestPlugin()
_ = exp.run(
root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
Expand All @@ -164,7 +165,7 @@ def test_example_ids(self):

def test_filter(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')

_ = exp.run(
Expand Down Expand Up @@ -193,7 +194,7 @@ def test_inputs(num_examples: int = 10):
) for i in range(num_examples)
]

exp = test_helper.TestEvaluation(
exp = eval_test_helper.TestEvaluation(
inputs=test_inputs(num_examples=pg.oneof([2, 4]))
)
# Global cache.
Expand Down Expand Up @@ -234,7 +235,7 @@ class ParallelRunnerTest(RunnerTest):

def test_parallel_runner(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
run = exp.run(root_dir, runner='parallel', plugins=[plugin])

Expand Down Expand Up @@ -274,7 +275,7 @@ def test_parallel_runner(self):

def test_concurrent_startup_delay(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(
tempfile.gettempdir(), 'test_concurrent_startup_delay'
)
Expand All @@ -290,7 +291,7 @@ class DebugRunnerTest(RunnerTest):

def test_debug_runner(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
run = exp.run(root_dir, runner='debug', plugins=[plugin])

Expand Down

0 comments on commit a5d2bbc

Please sign in to comment.