Skip to content

Commit

Permalink
lf.eval.v2 enhancements:
Browse files Browse the repository at this point in the history
1) Expose `Evaluation.state` so users can access evaluated examples for plugin development.
2) Support conditional serialization of `lf.eval.v2.Example.input` with `pg.Refs`.

PiperOrigin-RevId: 696284525
  • Loading branch information
daiyip authored and langfun authors committed Nov 14, 2024
1 parent 5f661f7 commit a34ab2e
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 15 deletions.
4 changes: 4 additions & 0 deletions langfun/core/eval/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
from langfun.core.eval.v2.metrics import Metric
from langfun.core.eval.v2 import metrics

from langfun.core.eval.v2.experiment import Plugin

from langfun.core.eval.v2.experiment import Runner
from langfun.core.eval.v2 import runners


# pylint: enable=g-bad-import-order
# pylint: enable=g-importing-member
7 changes: 6 additions & 1 deletion langfun/core/eval/v2/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@ def __init__(self, path: str):
self._sequence_writer = pg.io.open_sequence(path, 'w')

def add(self, example: Example):
example_blob = pg.to_json_str(example, hide_default_values=True)
example_blob = pg.to_json_str(
example,
hide_default_values=True,
save_ref_value=True,
exclude_input=True
)
with self._lock:
if self._sequence_writer is None:
return
Expand Down
5 changes: 5 additions & 0 deletions langfun/core/eval/v2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ def resource_ids(self) -> set[str]:
# Handling evaluation state.
#

@property
def state(self) -> 'EvaluationState':
"""Returns the state of the evaluation."""
# Read-only accessor for the private `_state` attribute. Per the commit
# description it is exposed so callers (e.g. plugin code) can reach
# evaluated examples; the accompanying test looks one up by id with
# `exp.state.get(3)`.
return self._state

def load_state(
self, state_file: str, raise_if_not_exist: bool = False
) -> None:
Expand Down
1 change: 1 addition & 0 deletions langfun/core/eval/v2/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_hyper_evaluation(self):
def test_evaluate(self):
exp = test_helper.TestEvaluation()
example = exp.evaluate(Example(id=3))
self.assertIs(exp.state.get(3), example)
self.assertTrue(example.newly_processed)
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
self.assertEqual(example.output, 6)
Expand Down
35 changes: 22 additions & 13 deletions langfun/core/eval/v2/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,15 @@ def elapse(self) -> float | None:
return self.execution_status['evaluate'].elapse
return None

def to_json(self, **kwargs) -> dict[str, Any]:
def to_json(self, *, exclude_input: bool = False, **kwargs):
"""Returns the JSON representation of the item."""
return self.to_json_dict(
fields=dict(
id=(self.id, None),
# NOTE(daiyip): We do not write `input` to JSON as it will be
# loaded from the input functor. This allows us to support
# non-serializable examples.
input=(
self.input if not exclude_input else pg.MISSING_VALUE,
pg.MISSING_VALUE
),
output=(self.output, pg.MISSING_VALUE),
error=(self.error, None),
metadata=(self.metadata, {}),
Expand All @@ -99,12 +100,17 @@ def from_json(
cls,
json_value: dict[str, Any],
*,
example_input_by_id: Callable[[int], Any],
example_input_by_id: Callable[[int], Any] | None = None,
**kwargs
) -> 'Example':
"""Creates an example from the JSON representation."""
example_id = json_value.get('id')
example_input = example_input_by_id(example_id)
if example_input_by_id:
example_input = example_input_by_id(example_id)
else:
example_input = json_value.pop('input', pg.MISSING_VALUE)
if example_input is not pg.MISSING_VALUE:
example_input = pg.from_json(example_input, **kwargs)
json_value['input'] = example_input

# NOTE(daiyip): We need to load the types of the examples into the
Expand Down Expand Up @@ -205,26 +211,29 @@ def _render_header():
)

def _render_content():
def _tab(label, key):
def _tab(label, key, default):
field = getattr(self, key)
if pg.MISSING_VALUE == field or not field:
if default == field:
return None
return pg.views.html.controls.Tab(
label=label,
content=view.render(
field,
root_path=root_path + key,
collapse_level=None,
**view.get_passthrough_kwargs(**kwargs),
),
)
tabs = [
_tab('Input', 'input'),
_tab('Output', 'output'),
_tab('Output Metadata', 'metadata'),
_tab('Error', 'error'),
_tab('Input', 'input', pg.MISSING_VALUE),
_tab('Output', 'output', pg.MISSING_VALUE),
_tab('Output Metadata', 'metadata', {}),
_tab('Error', 'error', None),
]
tabs = [tab for tab in tabs if tab is not None]
return pg.views.html.controls.TabControl(
[tab for tab in tabs if tab is not None]
tabs,
len(tabs) - 1,
)

return pg.Html.element(
Expand Down
19 changes: 18 additions & 1 deletion langfun/core/eval/v2/example_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,31 @@ class B(pg.Object):
output=inputs[0].a(1),
metadata=dict(b=inputs[0].b())
)
json_str = pg.to_json_str(ex)
# Serialize without input.
json_str = pg.to_json_str(ex, exclude_input=True)
self.assertEqual(
pg.from_json_str(
json_str,
example_input_by_id=lambda i: inputs[i - 1]
),
ex
)
pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.clear()
v = pg.from_json_str(json_str, auto_dict=True)
v.output.pop('type_name')
v.metadata.b.pop('type_name')
self.assertEqual(
v,
Example(
id=1,
output=pg.Dict(x=1),
metadata=dict(b=pg.Dict(x=1, y=2)),
)
)
# Serialize with input.
ex = Example(id=2, input=pg.Dict(x=1), output=pg.Dict(x=2))
json_str = pg.to_json_str(ex, exclude_input=False)
self.assertEqual(pg.from_json_str(json_str), ex)

def test_html_view(self):
ex = Example(
Expand Down

0 comments on commit a34ab2e

Please sign in to comment.