diff --git a/langfun/core/eval/v2/checkpointing.py b/langfun/core/eval/v2/checkpointing.py index 1b18419..65ca6c9 100644 --- a/langfun/core/eval/v2/checkpointing.py +++ b/langfun/core/eval/v2/checkpointing.py @@ -66,6 +66,10 @@ def on_experiment_start( # For refresh runs, we don't want to load the previous state. if not runner.current_run.refresh: + if runner.current_run.input_root != runner.current_run.output_root: + experiment.info( + f'Warm starting from directory: {runner.current_run.input_root}.' + ) def _load_state(ckpt_file): experiment.load_state(ckpt_file) @@ -85,8 +89,8 @@ def _load_state(ckpt_file): ): if error is not None: experiment.warning( - 'Failed to load checkpoint file %s: %s. Skipping the file.', - ckpt_file, error + f'Failed to load checkpoint file {ckpt_file}: {error}. ' + 'Skipping the file.' ) super().on_experiment_start(experiment) @@ -181,6 +185,10 @@ def on_experiment_start( return # For refresh runs, we don't want to load the previous state. if not runner.current_run.refresh: + if runner.current_run.input_root != runner.current_run.output_root: + experiment.info( + f'Warm starting from directory: {runner.current_run.input_root}.' + ) experiment.load_state( runner.current_run.input_path_for( experiment, self.checkpoint_filename diff --git a/langfun/core/eval/v2/checkpointing_test.py b/langfun/core/eval/v2/checkpointing_test.py index 5425ad4..b1537a5 100644 --- a/langfun/core/eval/v2/checkpointing_test.py +++ b/langfun/core/eval/v2/checkpointing_test.py @@ -16,9 +16,9 @@ import unittest from langfun.core.eval.v2 import checkpointing +from langfun.core.eval.v2 import eval_test_helper from langfun.core.eval.v2 import example as example_lib from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import -from langfun.core.eval.v2 import test_helper import pyglove as pg Example = example_lib.Example @@ -56,7 +56,7 @@ class PerExampleCheckpointerTest(unittest.TestCase): def test_checkpointing(self): root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer') - experiment = test_helper.test_experiment() + experiment = eval_test_helper.test_experiment() checkpoint_filename = 'checkpoint.jsonl' checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename) run = experiment.run( @@ -89,7 +89,7 @@ class BulkCheckpointerTest(unittest.TestCase): def test_checkpointing(self): root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer') - experiment = test_helper.test_experiment() + experiment = eval_test_helper.test_experiment() checkpoint_filename = 'checkpoint.jsonl' checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename) run = experiment.run( diff --git a/langfun/core/eval/v2/test_helper.py b/langfun/core/eval/v2/eval_test_helper.py similarity index 100% rename from langfun/core/eval/v2/test_helper.py rename to langfun/core/eval/v2/eval_test_helper.py diff --git a/langfun/core/eval/v2/evaluation_test.py b/langfun/core/eval/v2/evaluation_test.py index 92331f9..413c844 100644 --- a/langfun/core/eval/v2/evaluation_test.py +++ b/langfun/core/eval/v2/evaluation_test.py @@ -15,12 +15,11 @@ import tempfile import unittest +from langfun.core.eval.v2 import eval_test_helper from langfun.core.eval.v2 import evaluation as evaluation_lib from langfun.core.eval.v2 import example as example_lib from langfun.core.eval.v2 import experiment as experiment_lib -from langfun.core.eval.v2 import test_helper - import pyglove as pg Example = example_lib.Example @@ -32,17 +31,23 @@ class EvaluationTest(unittest.TestCase): def test_hyper_evaluation(self): - exp = test_helper.TestEvaluation( - lm=test_helper.TestLLM(offset=pg.oneof(range(3))) + exp = eval_test_helper.TestEvaluation( + lm=eval_test_helper.TestLLM(offset=pg.oneof(range(3))) ) self.assertFalse(exp.is_leaf) self.assertTrue( pg.eq( exp.children, [ - test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=0)), - test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1)), - test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=2)), + eval_test_helper.TestEvaluation( + lm=eval_test_helper.TestLLM(offset=0) + ), + eval_test_helper.TestEvaluation( + lm=eval_test_helper.TestLLM(offset=1) + ), + eval_test_helper.TestEvaluation( + lm=eval_test_helper.TestLLM(offset=2) + ), ] ) ) @@ -57,19 +62,21 @@ def test_hyper_evaluation(self): ) def test_input(self): - exp = test_helper.TestEvaluation() + exp = eval_test_helper.TestEvaluation() self.assertEqual(exp.num_examples, 10) - exp = test_helper.TestEvaluation(inputs=test_helper.test_inputs(None)) + exp = eval_test_helper.TestEvaluation( + inputs=eval_test_helper.test_inputs(None) + ) self.assertEqual(exp.num_examples, 20) @pg.functor def my_inputs(): yield pg.Dict(x=1, y=2) yield pg.Dict(x=3, y=4) - exp = test_helper.TestEvaluation(inputs=my_inputs()) + exp = eval_test_helper.TestEvaluation(inputs=my_inputs()) self.assertEqual(exp.num_examples, 2) def test_evaluate(self): - exp = test_helper.TestEvaluation() + exp = eval_test_helper.TestEvaluation() example = exp.evaluate(Example(id=3)) self.assertIs(exp.state.get(3), example) self.assertTrue(example.newly_processed) @@ -85,7 +92,7 @@ def test_evaluate(self): self.assertIsNotNone(example.start_time) self.assertIsNotNone(example.end_time) - exp = test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1)) + exp = eval_test_helper.TestEvaluation(lm=eval_test_helper.TestLLM(offset=1)) example = exp.evaluate(3) self.assertTrue(example.newly_processed) self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6)) @@ -109,7 +116,7 @@ def test_evaluate_with_state(self): pg.io.mkdirs(eval_dir, exist_ok=True) state_file = os.path.join(eval_dir, 'state.jsonl') with pg.io.open_sequence(state_file, 'w') as f: - exp = test_helper.TestEvaluation() + exp = eval_test_helper.TestEvaluation() example = exp.evaluate(3) self.assertTrue(example.newly_processed) self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6)) @@ -132,7 +139,7 @@ def test_evaluate_with_state(self): self.assertEqual(example.usage_summary.uncached.total.num_requests, 0) def test_html_view(self): - exp = test_helper.TestEvaluation() + exp = eval_test_helper.TestEvaluation() exp.debug('debug message') exp.info('info message') exp.warning('warning message', x=1) diff --git a/langfun/core/eval/v2/progress_tracking_test.py b/langfun/core/eval/v2/progress_tracking_test.py index ea13a49..24aaa27 100644 --- a/langfun/core/eval/v2/progress_tracking_test.py +++ b/langfun/core/eval/v2/progress_tracking_test.py @@ -18,9 +18,9 @@ import unittest from langfun.core import console as lf_console +from langfun.core.eval.v2 import eval_test_helper from langfun.core.eval.v2 import progress_tracking # pylint: disable=unused-import from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import -from langfun.core.eval.v2 import test_helper import pyglove as pg @@ -35,7 +35,7 @@ def display(x): display=display ) root_dir = os.path.join(tempfile.gettempdir(), 'test_html_progress_tracker') - experiment = test_helper.test_experiment() + experiment = eval_test_helper.test_experiment() _ = experiment.run(root_dir, 'new', plugins=[]) self.assertIsInstance(result['view'], pg.Html) lf_console._notebook = None @@ -45,7 +45,7 @@ class TqdmProgressTrackerTest(unittest.TestCase): def test_basic(self): root_dir = os.path.join(tempfile.gettempdir(), 'test_tqdm_progress_tracker') - experiment = test_helper.test_experiment() + experiment = eval_test_helper.test_experiment() string_io = io.StringIO() with contextlib.redirect_stderr(string_io): _ = experiment.run(root_dir, 'new', plugins=[]) @@ -55,7 +55,7 @@ def test_with_example_ids(self): root_dir = os.path.join( tempfile.gettempdir(), 'test_tqdm_progress_tracker_with_example_ids' ) - experiment = test_helper.test_experiment() + experiment = eval_test_helper.test_experiment() string_io = io.StringIO() with contextlib.redirect_stderr(string_io): _ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[]) diff --git a/langfun/core/eval/v2/reporting_test.py b/langfun/core/eval/v2/reporting_test.py index 9e27b86..f027a12 100644 --- a/langfun/core/eval/v2/reporting_test.py +++ b/langfun/core/eval/v2/reporting_test.py @@ -15,9 +15,9 @@ import tempfile import unittest +from langfun.core.eval.v2 import eval_test_helper from langfun.core.eval.v2 import reporting from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import -from langfun.core.eval.v2 import test_helper import pyglove as pg @@ -25,7 +25,7 @@ class ReportingTest(unittest.TestCase): def test_reporting(self): root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting') - experiment = test_helper.test_experiment() + experiment = eval_test_helper.test_experiment() reporter = reporting.HtmlReporter() run = experiment.run(root_dir, 'new', plugins=[reporter]) pg.io.path_exists(run.output_path_for(experiment, 'summary.html')) diff --git a/langfun/core/eval/v2/runners.py b/langfun/core/eval/v2/runners.py index caeae77..b4aecd9 100644 --- a/langfun/core/eval/v2/runners.py +++ b/langfun/core/eval/v2/runners.py @@ -65,6 +65,7 @@ def _on_bound(self): with pg.notify_on_change(False): self.plugins.append(progress_tracking.progress_tracker(self.tqdm)) + self._io_pool_lock = threading.Lock() self._io_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16) # TODO(daiyip): render background errors. self._background_last_error = None @@ -76,7 +77,10 @@ def _background_run(*args, **kwargs): func(*args, **kwargs) except Exception as e: # pylint: disable=broad-except self._background_last_error = e - self._io_pool.submit(_background_run, *args, **kwargs) + + with self._io_pool_lock: + if self._io_pool is not None: + self._io_pool.submit(_background_run, *args, **kwargs) def _all_plugins(self, experiment: Experiment) -> Iterator[Plugin]: """Returns all plugins for the experiment.""" @@ -296,7 +300,9 @@ def run(self) -> None: self.background_run(cache.save) # Wait for the background tasks to finish. - self._io_pool.shutdown(wait=True) + with self._io_pool_lock: + self._io_pool, io_pool = None, self._io_pool + io_pool.shutdown(wait=True) @abc.abstractmethod def _run(self, evaluations: list[Evaluation]) -> None: diff --git a/langfun/core/eval/v2/runners_test.py b/langfun/core/eval/v2/runners_test.py index 82e6268..e645f5e 100644 --- a/langfun/core/eval/v2/runners_test.py +++ b/langfun/core/eval/v2/runners_test.py @@ -18,10 +18,11 @@ from typing import Any import unittest +from langfun.core.eval.v2 import eval_test_helper from langfun.core.eval.v2 import example as example_lib from langfun.core.eval.v2 import experiment as experiment_lib from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import -from langfun.core.eval.v2 import test_helper + import pyglove as pg @@ -101,7 +102,7 @@ def assert_same_list(self, actual: list[Any], expected: list[Any]): def test_basic(self): plugin = TestPlugin() - exp = test_helper.test_experiment() + exp = eval_test_helper.test_experiment() root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner') run = exp.run(root_dir, runner='sequential', plugins=[plugin]) @@ -143,7 +144,7 @@ def test_basic(self): def test_raise_if_has_error(self): root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error') - exp = test_helper.TestEvaluation() + exp = eval_test_helper.TestEvaluation() with self.assertRaisesRegex(ValueError, 'x should not be 5'): exp.run( root_dir, runner='sequential', plugins=[], raise_if_has_error=True @@ -154,7 +155,7 @@ def test_raise_if_has_error(self): def test_example_ids(self): root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids') - exp = test_helper.test_experiment() + exp = eval_test_helper.test_experiment() plugin = TestPlugin() _ = exp.run( root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9] @@ -164,7 +165,7 @@ def test_example_ids(self): def test_filter(self): plugin = TestPlugin() - exp = test_helper.test_experiment() + exp = eval_test_helper.test_experiment() root_dir = os.path.join(tempfile.gettempdir(), 'test_filter') _ = exp.run( @@ -193,7 +194,7 @@ def test_inputs(num_examples: int = 10): ) for i in range(num_examples) ] - exp = test_helper.TestEvaluation( + exp = eval_test_helper.TestEvaluation( inputs=test_inputs(num_examples=pg.oneof([2, 4])) ) # Global cache. @@ -234,7 +235,7 @@ class ParallelRunnerTest(RunnerTest): def test_parallel_runner(self): plugin = TestPlugin() - exp = test_helper.test_experiment() + exp = eval_test_helper.test_experiment() root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner') run = exp.run(root_dir, runner='parallel', plugins=[plugin]) @@ -274,7 +275,7 @@ def test_parallel_runner(self): def test_concurrent_startup_delay(self): plugin = TestPlugin() - exp = test_helper.test_experiment() + exp = eval_test_helper.test_experiment() root_dir = os.path.join( tempfile.gettempdir(), 'test_concurrent_startup_delay' ) @@ -290,7 +291,7 @@ class DebugRunnerTest(RunnerTest): def test_debug_runner(self): plugin = TestPlugin() - exp = test_helper.test_experiment() + exp = eval_test_helper.test_experiment() root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner') run = exp.run(root_dir, runner='debug', plugins=[plugin])