Skip to content

Commit

Permalink
Langfun Evaluation Framework V2
Browse files Browse the repository at this point in the history
Key Features:

1. **Cleaner and more flexible design**: Unlike V1, where evaluation was tightly integrated with lf.query for single-turn LLM assessments, V2 introduces a decoupled structure. This allows for a cleaner, more flexible design that can accommodate multi-turn logic, such as agent interactions, and clearly defines what elements users should override.

2. **Automatic Sweeping with pg.oneof**: V2 continues to support parameter sweeping with pg.oneof, as in V1. However, unlike V1, where only specific parameters could be swept, V2 allows any member variable within the user's class to be sweepable using pg.oneof.

3. **Multi-metric support**: V2 evaluation now has two distinct phases for each example: processing and metric auditing. This separation lets users recompute metrics based on processed examples without needing to reprocess them, improving efficiency.

4. **Improved running organization**: Evaluations are no longer bound to a single root_dir. Instead, root_dir is specified at runtime, with options to select latest, new, or a specific run ID to resume or start a new experiment. This structure makes it easier to refresh and manage ongoing experiments.

5. **Enhanced progress reporting**: Beyond providing a TQDM progress bar, Langfun now integrates PyGlove’s latest HTML summary features, updated every minute. This enhanced view provides more granular information, such as execution time and total tokens used.

6. **Revamped HTML reporting**: V2 introduces a more user-friendly and detailed HTML reporting interface, with streaming capabilities that allow users to analyze individual examples as they are processed, without waiting for the full experiment to complete.

7. **Enhanced LLM cache options**: V2 supports flexible LLM cache settings, including 'global', 'per_dataset', and 'no', to manage cache more effectively across experiments.

8. **Checkpointing**: For the first time, checkpointing is available, enabling experiments to pick up where they left off. Unlike the LLM cache, checkpointing reuses processed examples while allowing metric recalculation.

9. **Plugin support**: The V2 framework is highly extensible with plugins, enabling listeners to respond to various evaluation events. Common plotting support and additional plugins will be introduced in future updates.

PiperOrigin-RevId: 695536246
  • Loading branch information
daiyip authored and langfun authors committed Nov 12, 2024
1 parent 2dfd905 commit 69474a9
Show file tree
Hide file tree
Showing 27 changed files with 5,372 additions and 17 deletions.
12 changes: 10 additions & 2 deletions langfun/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,20 @@ def under_notebook() -> bool:
return bool(_notebook)


def display(value: Any, clear: bool = False) -> None: # pylint: disable=redefined-outer-name
def display(value: Any, clear: bool = False) -> Any: # pylint: disable=redefined-outer-name
"""Displays object in current notebook cell."""
if _notebook is not None:
if clear:
_notebook.clear_output()
_notebook.display(value)
return _notebook.display(value)
return None


def run_script(javascript: str) -> Any:
"""Runs JavaScript in current notebook cell."""
if _notebook is not None:
return _notebook.display(_notebook.Javascript(javascript))
return


def clear() -> None:
Expand Down
17 changes: 17 additions & 0 deletions langfun/core/console_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import unittest

from langfun.core import console
import pyglove as pg


class ConsoleTest(unittest.TestCase):
Expand All @@ -32,6 +33,22 @@ def test_write(self):

def test_under_notebook(self):
self.assertFalse(console.under_notebook())
console._notebook = True
self.assertTrue(console.under_notebook())
console._notebook = None

def test_notebook_interaction(self):
console._notebook = pg.Dict(
display=lambda x: x, Javascript=lambda x: x, clear_output=lambda: None)
self.assertEqual(console.display('hi', clear=True), 'hi')
self.assertEqual(
console.run_script('console.log("hi")'),
'console.log("hi")'
)
console.clear()
console._notebook = None
self.assertIsNone(console.display('hi'))
self.assertIsNone(console.run_script('console.log("hi")'))


if __name__ == '__main__':
Expand Down
2 changes: 2 additions & 0 deletions langfun/core/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order

from langfun.core.eval import v2

from langfun.core.eval.base import register
from langfun.core.eval.base import registered_names
from langfun.core.eval.base import get_evaluations
Expand Down
34 changes: 34 additions & 0 deletions langfun/core/eval/v2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""langfun eval framework v2."""

# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order
from langfun.core.eval.v2.experiment import Experiment
from langfun.core.eval.v2.experiment import Suite
from langfun.core.eval.v2.evaluation import Evaluation

from langfun.core.eval.v2.example import Example
from langfun.core.eval.v2.progress import Progress

from langfun.core.eval.v2.metric_values import MetricValue
from langfun.core.eval.v2.metric_values import Rate
from langfun.core.eval.v2.metric_values import Average
from langfun.core.eval.v2.metrics import Metric
from langfun.core.eval.v2 import metrics

from langfun.core.eval.v2 import runners

# pylint: enable=g-bad-import-order
# pylint: enable=g-importing-member
130 changes: 130 additions & 0 deletions langfun/core/eval/v2/checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpointing evaluation runs."""
import threading

from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import experiment as experiment_lib
import pyglove as pg

Example = example_lib.Example
Experiment = experiment_lib.Experiment
Runner = experiment_lib.Runner


class Checkpointer(experiment_lib.Plugin):
"""Plugin for checkpointing evaluation runs."""

checkpoint_filename: str = 'checkpoint.bagz'

def _on_bound(self):
super()._on_bound()
self._lock = threading.Lock()
self._state_writer = None

def on_run_start(
self,
runner: Runner,
root: Experiment,
) -> None:
self._state_writer = {}

def on_run_abort(
self,
runner: Runner,
root: Experiment,
error: BaseException
) -> None:
with self._lock:
if self._state_writer is not None:
self._state_writer.clear()

def on_run_complete(
self,
runner: Runner,
root: Experiment,
) -> None:
with self._lock:
assert self._state_writer is not None and not self._state_writer

def on_experiment_start(
self,
runner: Runner,
experiment: Experiment,
) -> None:
"""Creates the checkpoint file."""
if not experiment.is_leaf:
return
# For refresh runs, we don't want to load the previous state.
if not runner.current_run.refresh:
experiment.load_state(
runner.current_run.input_path_for(
experiment, self.checkpoint_filename
),
raise_if_not_exist=False
)
state_writer = StateWriter(
runner.current_run.output_path_for(
experiment, self.checkpoint_filename
)
)
with self._lock:
if self._state_writer is not None:
self._state_writer[experiment.id] = state_writer

def on_experiment_complete(
self,
runner: Runner,
experiment: Experiment,
) -> None:
"""Closes the checkpoint file."""
if not experiment.is_leaf:
return
assert experiment.id in self._state_writer
with self._lock:
if self._state_writer is not None:
del self._state_writer[experiment.id]

def on_example_complete(
self,
runner: Runner,
experiment: Experiment,
example: Example,
) -> None:
"""Saves the example to the checkpoint file."""
assert experiment.id in self._state_writer
if not example.has_error:
runner.background_run(self._state_writer[experiment.id].add, example)


class StateWriter:
"""Thread safe state writer."""

def __init__(self, path: str):
self._lock = threading.Lock()
self._sequence_writer = pg.io.open_sequence(path, 'w')

def add(self, example: Example):
example_blob = pg.to_json_str(example, hide_default_values=True)
with self._lock:
if self._sequence_writer is None:
return
self._sequence_writer.add(example_blob)

def __del__(self):
# Make sure there is no write in progress.
with self._lock:
assert self._sequence_writer is not None
self._sequence_writer.close()
self._sequence_writer = None
89 changes: 89 additions & 0 deletions langfun/core/eval/v2/checkpointing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest

from langfun.core.eval.v2 import checkpointing
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
from langfun.core.eval.v2 import test_helper
import pyglove as pg

Example = example_lib.Example


class StateWriterTest(unittest.TestCase):

def test_basic(self):
file = os.path.join(tempfile.gettempdir(), 'test.jsonl')
writer = checkpointing.StateWriter(file)
example = Example(id=1, input=pg.Dict(x=1), output=2)
writer.add(example)
del writer
self.assertTrue(pg.io.path_exists(file))

def test_error_handling(self):
file = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
writer = checkpointing.StateWriter(file)
writer.add(Example(id=1, input=pg.Dict(x=1), output=2))

def f():
raise ValueError('Intentional error')

try:
writer.add(f())
except ValueError:
del writer

self.assertTrue(pg.io.path_exists(file))
with pg.io.open_sequence(file, 'r') as f:
self.assertEqual(len(list(iter(f))), 1)


class CheckpointingTest(unittest.TestCase):

def test_checkpointing(self):
root_dir = os.path.join(tempfile.gettempdir(), 'test_checkpointing')
experiment = test_helper.test_experiment()
checkpoint_filename = 'checkpoint.jsonl'
checkpointer = checkpointing.Checkpointer(checkpoint_filename)
run = experiment.run(
root_dir, 'new', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._state_writer), 0)
num_processed = {}
for leaf in experiment.leaf_nodes:
ckpt = run.output_path_for(leaf, checkpoint_filename)
self.assertTrue(pg.io.path_exists(ckpt))
with pg.io.open_sequence(ckpt) as f:
self.assertEqual(
len(list(iter(f))),
leaf.progress.num_completed - leaf.progress.num_failed
)
if leaf.id not in num_processed:
self.assertEqual(leaf.progress.num_skipped, 0)
num_processed[leaf.id] = leaf.progress.num_processed

# Run again, should skip existing.
_ = experiment.run(
root_dir, 'latest', runner='sequential', plugins=[checkpointer]
)
self.assertEqual(len(checkpointer._state_writer), 0)
for leaf in experiment.leaf_nodes:
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 69474a9

Please sign in to comment.