Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refining lf.eval.v2. #382

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions langfun/core/eval/v2/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def on_experiment_start(

# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
if runner.current_run.input_root != runner.current_run.output_root:
experiment.info(
f'Warm starting from directory: {runner.current_run.input_root}.'
)
def _load_state(ckpt_file):
experiment.load_state(ckpt_file)

Expand All @@ -85,8 +89,8 @@ def _load_state(ckpt_file):
):
if error is not None:
experiment.warning(
'Failed to load checkpoint file %s: %s. Skipping the file.',
ckpt_file, error
f'Failed to load checkpoint file {ckpt_file}: {error}. '
'Skipping the file.'
)
super().on_experiment_start(experiment)

Expand Down Expand Up @@ -181,6 +185,10 @@ def on_experiment_start(
return
# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
if runner.current_run.input_root != runner.current_run.output_root:
experiment.info(
f'Warm starting from directory: {runner.current_run.input_root}.'
)
experiment.load_state(
runner.current_run.input_path_for(
experiment, self.checkpoint_filename
Expand Down
6 changes: 3 additions & 3 deletions langfun/core/eval/v2/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import unittest

from langfun.core.eval.v2 import checkpointing
from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper
import pyglove as pg

Example = example_lib.Example
Expand Down Expand Up @@ -56,7 +56,7 @@ class PerExampleCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
run = experiment.run(
Expand Down Expand Up @@ -89,7 +89,7 @@ class BulkCheckpointerTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
run = experiment.run(
Expand Down
35 changes: 21 additions & 14 deletions langfun/core/eval/v2/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
import tempfile
import unittest

from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import evaluation as evaluation_lib
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib

from langfun.core.eval.v2 import test_helper

import pyglove as pg

Example = example_lib.Example
Expand All @@ -32,17 +31,23 @@
class EvaluationTest(unittest.TestCase):

def test_hyper_evaluation(self):
exp = test_helper.TestEvaluation(
lm=test_helper.TestLLM(offset=pg.oneof(range(3)))
exp = eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=pg.oneof(range(3)))
)
self.assertFalse(exp.is_leaf)
self.assertTrue(
pg.eq(
exp.children,
[
test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=0)),
test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1)),
test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=2)),
eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=0)
),
eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=1)
),
eval_test_helper.TestEvaluation(
lm=eval_test_helper.TestLLM(offset=2)
),
]
)
)
Expand All @@ -57,19 +62,21 @@ def test_hyper_evaluation(self):
)

def test_input(self):
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
self.assertEqual(exp.num_examples, 10)
exp = test_helper.TestEvaluation(inputs=test_helper.test_inputs(None))
exp = eval_test_helper.TestEvaluation(
inputs=eval_test_helper.test_inputs(None)
)
self.assertEqual(exp.num_examples, 20)
@pg.functor
def my_inputs():
yield pg.Dict(x=1, y=2)
yield pg.Dict(x=3, y=4)
exp = test_helper.TestEvaluation(inputs=my_inputs())
exp = eval_test_helper.TestEvaluation(inputs=my_inputs())
self.assertEqual(exp.num_examples, 2)

def test_evaluate(self):
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
example = exp.evaluate(Example(id=3))
self.assertIs(exp.state.get(3), example)
self.assertTrue(example.newly_processed)
Expand All @@ -85,7 +92,7 @@ def test_evaluate(self):
self.assertIsNotNone(example.start_time)
self.assertIsNotNone(example.end_time)

exp = test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1))
exp = eval_test_helper.TestEvaluation(lm=eval_test_helper.TestLLM(offset=1))
example = exp.evaluate(3)
self.assertTrue(example.newly_processed)
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
Expand All @@ -109,7 +116,7 @@ def test_evaluate_with_state(self):
pg.io.mkdirs(eval_dir, exist_ok=True)
state_file = os.path.join(eval_dir, 'state.jsonl')
with pg.io.open_sequence(state_file, 'w') as f:
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
example = exp.evaluate(3)
self.assertTrue(example.newly_processed)
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
Expand All @@ -132,7 +139,7 @@ def test_evaluate_with_state(self):
self.assertEqual(example.usage_summary.uncached.total.num_requests, 0)

def test_html_view(self):
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
exp.debug('debug message')
exp.info('info message')
exp.warning('warning message', x=1)
Expand Down
8 changes: 4 additions & 4 deletions langfun/core/eval/v2/progress_tracking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import unittest

from langfun.core import console as lf_console
from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import progress_tracking # pylint: disable=unused-import
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper
import pyglove as pg


Expand All @@ -35,7 +35,7 @@ def display(x):
display=display
)
root_dir = os.path.join(tempfile.gettempdir(), 'test_html_progress_tracker')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
_ = experiment.run(root_dir, 'new', plugins=[])
self.assertIsInstance(result['view'], pg.Html)
lf_console._notebook = None
Expand All @@ -45,7 +45,7 @@ class TqdmProgressTrackerTest(unittest.TestCase):

def test_basic(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_tqdm_progress_tracker')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
_ = experiment.run(root_dir, 'new', plugins=[])
Expand All @@ -55,7 +55,7 @@ def test_with_example_ids(self):
root_dir = os.path.join(
tempfile.gettempdir(), 'test_tqdm_progress_tracker_with_example_ids'
)
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
_ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
Expand Down
4 changes: 2 additions & 2 deletions langfun/core/eval/v2/reporting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
import tempfile
import unittest

from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import reporting
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper
import pyglove as pg


class ReportingTest(unittest.TestCase):

def test_reporting(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting')
experiment = test_helper.test_experiment()
experiment = eval_test_helper.test_experiment()
reporter = reporting.HtmlReporter()
run = experiment.run(root_dir, 'new', plugins=[reporter])
pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
Expand Down
10 changes: 8 additions & 2 deletions langfun/core/eval/v2/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _on_bound(self):
with pg.notify_on_change(False):
self.plugins.append(progress_tracking.progress_tracker(self.tqdm))

self._io_pool_lock = threading.Lock()
self._io_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16)
# TODO(daiyip): render background errors.
self._background_last_error = None
Expand All @@ -76,7 +77,10 @@ def _background_run(*args, **kwargs):
func(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
self._background_last_error = e
self._io_pool.submit(_background_run, *args, **kwargs)

with self._io_pool_lock:
if self._io_pool is not None:
self._io_pool.submit(_background_run, *args, **kwargs)

def _all_plugins(self, experiment: Experiment) -> Iterator[Plugin]:
"""Returns all plugins for the experiment."""
Expand Down Expand Up @@ -296,7 +300,9 @@ def run(self) -> None:
self.background_run(cache.save)

# Wait for the background tasks to finish.
self._io_pool.shutdown(wait=True)
with self._io_pool_lock:
self._io_pool, io_pool = None, self._io_pool
io_pool.shutdown(wait=True)

@abc.abstractmethod
def _run(self, evaluations: list[Evaluation]) -> None:
Expand Down
19 changes: 10 additions & 9 deletions langfun/core/eval/v2/runners_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from typing import Any
import unittest

from langfun.core.eval.v2 import eval_test_helper
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper

import pyglove as pg


Expand Down Expand Up @@ -101,7 +102,7 @@ def assert_same_list(self, actual: list[Any], expected: list[Any]):

def test_basic(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
run = exp.run(root_dir, runner='sequential', plugins=[plugin])

Expand Down Expand Up @@ -143,7 +144,7 @@ def test_basic(self):

def test_raise_if_has_error(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
exp = test_helper.TestEvaluation()
exp = eval_test_helper.TestEvaluation()
with self.assertRaisesRegex(ValueError, 'x should not be 5'):
exp.run(
root_dir, runner='sequential', plugins=[], raise_if_has_error=True
Expand All @@ -154,7 +155,7 @@ def test_raise_if_has_error(self):

def test_example_ids(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
plugin = TestPlugin()
_ = exp.run(
root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
Expand All @@ -164,7 +165,7 @@ def test_example_ids(self):

def test_filter(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')

_ = exp.run(
Expand Down Expand Up @@ -193,7 +194,7 @@ def test_inputs(num_examples: int = 10):
) for i in range(num_examples)
]

exp = test_helper.TestEvaluation(
exp = eval_test_helper.TestEvaluation(
inputs=test_inputs(num_examples=pg.oneof([2, 4]))
)
# Global cache.
Expand Down Expand Up @@ -234,7 +235,7 @@ class ParallelRunnerTest(RunnerTest):

def test_parallel_runner(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
run = exp.run(root_dir, runner='parallel', plugins=[plugin])

Expand Down Expand Up @@ -274,7 +275,7 @@ def test_parallel_runner(self):

def test_concurrent_startup_delay(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(
tempfile.gettempdir(), 'test_concurrent_startup_delay'
)
Expand All @@ -290,7 +291,7 @@ class DebugRunnerTest(RunnerTest):

def test_debug_runner(self):
plugin = TestPlugin()
exp = test_helper.test_experiment()
exp = eval_test_helper.test_experiment()
root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
run = exp.run(root_dir, runner='debug', plugins=[plugin])

Expand Down
Loading