From 90ccbce8e862bdf295bf8767259a18f501a74921 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Wed, 13 Nov 2024 20:10:36 -0800 Subject: [PATCH] `lf.agentic`: Langfun's framework for LLM agents. PiperOrigin-RevId: 696366841 --- langfun/__init__.py | 4 + langfun/core/agentic/__init__.py | 26 +++ langfun/core/agentic/action.py | 247 ++++++++++++++++++++++++++ langfun/core/agentic/action_test.py | 84 +++++++++ langfun/core/concurrent.py | 1 + langfun/core/eval/__init__.py | 1 + langfun/core/eval/action_eval.py | 129 ++++++++++++++ langfun/core/eval/action_eval_test.py | 86 +++++++++ 8 files changed, 578 insertions(+) create mode 100644 langfun/core/agentic/__init__.py create mode 100644 langfun/core/agentic/action.py create mode 100644 langfun/core/agentic/action_test.py create mode 100644 langfun/core/eval/action_eval.py create mode 100644 langfun/core/eval/action_eval_test.py diff --git a/langfun/__init__.py b/langfun/__init__.py index fa8c41d..685ab3f 100644 --- a/langfun/__init__.py +++ b/langfun/__init__.py @@ -53,6 +53,10 @@ from langfun.core import llms lm_cache = llms.cache.lm_cache +from langfun.core import agentic +Action = agentic.Action +Session = agentic.Session + from langfun.core import memories from langfun.core import modalities diff --git a/langfun/core/agentic/__init__.py b/langfun/core/agentic/__init__.py new file mode 100644 index 0000000..0258f76 --- /dev/null +++ b/langfun/core/agentic/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Langfun agentic framework.""" + +# pylint: disable=g-bad-import-order +# pylint: disable=g-importing-member +# pylint: disable=g-import-not-at-top + +from langfun.core.agentic.action import Action +from langfun.core.agentic.action import ActionInvocation +from langfun.core.agentic.action import Session + +# pylint: enable=g-bad-import-order +# pylint: enable=g-importing-member +# pylint: enable=g-import-not-at-top diff --git a/langfun/core/agentic/action.py b/langfun/core/agentic/action.py new file mode 100644 index 0000000..b9e7c1d --- /dev/null +++ b/langfun/core/agentic/action.py @@ -0,0 +1,247 @@ +# Copyright 2024 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Base classes for agentic actions."""

import abc
from typing import Annotated, Any, Optional, Union
import langfun.core as lf
import pyglove as pg


class Action(pg.Object):
  """Base class for agent actions.

  Subclasses implement `call`. Calling the action instance runs `call`
  inside a `Session`, which records the invocation and its result.
  """

  def _on_bound(self):
    super()._on_bound()
    # Result of the most recent invocation; `None` until the action is called.
    self._result = None

  @property
  def result(self) -> Any:
    """Returns the result of the most recent invocation of the action."""
    return self._result

  def __call__(
      self, session: Optional['Session'] = None, **kwargs) -> Any:
    """Executes the action and records the invocation in a session.

    Args:
      session: The session that tracks this invocation. A new `Session` is
        created when omitted.
      **kwargs: Keyword arguments forwarded to `call`.

    Returns:
      The value returned by `call`.
    """
    if session is None:
      session = Session()

    # Drop the result of any earlier invocation so that a failing `call` is
    # recorded with a `None` result rather than a stale value.
    self._result = None

    # `begin` stays outside of `try`: if it fails nothing was pushed onto the
    # session stack, so `end` must not run and pop an unrelated invocation.
    session.begin(self)
    try:
      self._result = self.call(session=session, **kwargs)
      return self._result
    finally:
      session.end(self)

  @abc.abstractmethod
  def call(self, session: 'Session', **kwargs) -> Any:
    """Subclasses to implement."""


class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
  """A class for capturing the invocation of an action."""
  action: Action
  result: Any = None
  execution: Annotated[
      list[Union['ActionInvocation', lf.logging.LogEntry]],
      'Child action invocations and log entries, in the order they occurred.'
  ] = []

  # Allow symbolic assignment without `rebind`.
  allow_symbolic_assignment = True

  @property
  def logs(self) -> list[lf.logging.LogEntry]:
    """Returns logs from execution sequence."""
    return [v for v in self.execution if isinstance(v, lf.logging.LogEntry)]

  @property
  def child_invocations(self) -> list['ActionInvocation']:
    """Returns child action invocations."""
    return [v for v in self.execution if isinstance(v, ActionInvocation)]

  def _html_tree_view_summary(
      self, *, view: pg.views.html.HtmlTreeView, **kwargs
  ):
    """Renders the summary; the root invocation has no summary."""
    if isinstance(self.action, RootAction):
      return None
    # The title is replaced by the rendered action. Use a default so that a
    # caller which passes no `title` does not trigger a KeyError.
    kwargs.pop('title', None)
    return view.summary(
        self,
        title=view.render(
            self.action, name='action', collapse_level=0,
            css_classes='invocation-title',
        ),
        **kwargs
    )

  def _html_tree_view_content(
      self,
      *,
      root_path: pg.KeyPath | None = None,
      collapse_level: int | None = None,
      view: pg.views.html.HtmlTreeView,
      **kwargs
  ):
    """Renders the result followed by the execution sequence.

    The execution sequence is split into a "prepare" phase (entries that
    precede the first child invocation) and one phase per child invocation
    (the invocation plus the entries that follow it).
    """
    prepare_phase = []
    current_phase = prepare_phase
    action_phases = []
    for item in self.execution:
      if isinstance(item, ActionInvocation):
        # Each child invocation opens a new phase.
        current_phase = []
        action_phases.append(current_phase)
      current_phase.append(item)

    def _render_phase(
        phase: list[ActionInvocation | lf.logging.LogEntry]
    ) -> pg.Html.WritableTypes:
      return pg.Html.element(
          'div',
          [
              view.render(item) for item in phase
          ]
      )

    def _render_action_phases(
        phases: list[list[ActionInvocation | lf.logging.LogEntry]]
    ) -> pg.Html.WritableTypes:
      if not phases:
        # No child invocations: emit nothing instead of an empty tab control.
        return None
      if len(phases) == 1:
        return _render_phase(phases[0])
      return pg.views.html.controls.TabControl(
          [
              pg.views.html.controls.Tab(
                  label=f'Step {i + 1}',
                  content=_render_phase(phase),
              )
              for i, phase in enumerate(phases)
          ],
      )

    result_name = 'final_result' if isinstance(
        self.action, RootAction) else 'result'
    return pg.Html.element(
        'div',
        [
            view.render(
                self.result,
                name=result_name,
                css_classes=[
                    f'invocation-{result_name}'.replace('_', '-')
                ]
            ),
            _render_phase(prepare_phase) if prepare_phase else None,
            _render_action_phases(action_phases)
        ]
    )

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        details.invocation-title {
          display: inline-block;
          background-color: #b1f0ff;
          border: 1px solid white;
        }
        details.invocation-result {
          border: 1px solid #eee;
        }
        details.invocation-final-result {
          border: 1px solid #eee;
          background-color: #fef78f;
        }
        """
    ]


class RootAction(Action):
  """A placeholder action for the root of the action tree."""

  def call(self, session: 'Session', **kwargs) -> Any:
    raise NotImplementedError('Shall not be called.')


class Session(pg.Object):
  """Session for performing an agentic task."""

  root_invocation: ActionInvocation = ActionInvocation(RootAction())

  def _on_bound(self):
    super()._on_bound()
    # Stack of in-flight invocations; the root invocation is always at the
    # bottom.
    self._invocation_stack = [self.root_invocation]

  @property
  def final_result(self) -> Any:
    """Returns the final result of the session."""
    return self.root_invocation.result

  @property
  def current_invocation(self) -> ActionInvocation:
    """Returns the current invocation."""
    assert self._invocation_stack
    return self._invocation_stack[-1]

  def begin(self, action: Action):
    """Signal the beginning of the execution of an action."""
    new_invocation = ActionInvocation(pg.maybe_ref(action))
    # Append with change notification turned off.
    with pg.notify_on_change(False):
      self.current_invocation.execution.append(new_invocation)
    self._invocation_stack.append(new_invocation)

  def end(self, action: Action):
    """Signal the end of the execution of an action."""
    # Check before popping: an `end` without a matching `begin` must not
    # remove the root invocation from the stack.
    assert len(self._invocation_stack) > 1, 'end() called without begin().'
    invocation = self._invocation_stack.pop(-1)
    invocation.result = action.result

    if len(self._invocation_stack) == 1:
      # A top-level action has finished: its result becomes the session's
      # final result.
      self.root_invocation.rebind(
          result=invocation.result,
          skip_notification=True,
          raise_on_no_change=False
      )

  def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
    """Appends a log entry to the current invocation's execution sequence."""
    with pg.notify_on_change(False):
      self.current_invocation.execution.append(
          lf.logging.log(
              level, message, indent=len(self._invocation_stack) - 1, **kwargs
          )
      )

  def debug(self, message: str, **kwargs):
    """Logs a debug message to the session."""
    self._log('debug', message, **kwargs)

  def info(self, message: str, **kwargs):
    """Logs an info message to the session."""
    self._log('info', message, **kwargs)

  def warning(self, message: str, **kwargs):
    """Logs a warning message to the session."""
    self._log('warning', message, **kwargs)

  def error(self, message: str, **kwargs):
    """Logs an error message to the session."""
    self._log('error', message, **kwargs)

  def fatal(self, message: str, **kwargs):
    """Logs a fatal message to the session."""
    self._log('fatal', message, **kwargs)

  def as_message(self) -> lf.AIMessage:
    """Returns the session as a message."""
    return lf.AIMessage(
        'Agentic task session.',
        result=self.root_invocation
    )


# Patch metadata and licence header of the next file in this patch:
# diff --git a/langfun/core/agentic/action_test.py b/langfun/core/agentic/action_test.py
# new file mode 100644
# index 0000000..25234e2
# --- /dev/null
# +++ b/langfun/core/agentic/action_test.py
# @@ -0,0 +1,84 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for base action."""

import unittest

import langfun.core as lf
from langfun.core.agentic import action as action_lib


class SessionTest(unittest.TestCase):
  """Exercises `Session` bookkeeping, logging and message conversion."""

  def test_basics(self):
    testcase = self

    class Bar(action_lib.Action):
      """Leaf action: logs once and returns a constant."""

      def call(self, session, **kwargs):
        testcase.assertIs(session.current_invocation.action, self)
        session.info('Begin Bar')
        return 2

    class Foo(action_lib.Action):
      """Parent action: logs once, then adds the result of `Bar` to `x`."""
      x: int

      def call(self, session, **kwargs):
        testcase.assertIs(session.current_invocation.action, self)
        session.info('Begin Foo', x=1)
        return self.x + Bar()(session)

    session = action_lib.Session()
    root = session.root_invocation

    # Before anything runs, the root placeholder is the current invocation.
    self.assertIsInstance(root.action, action_lib.RootAction)
    self.assertIs(session.current_invocation, session.root_invocation)

    self.assertEqual(Foo(1)(session), 3)

    # Expected tree: root -> Foo (1 log) -> Bar (1 log, no children).
    self.assertEqual(len(session.root_invocation.child_invocations), 1)
    foo_invocation = session.root_invocation.child_invocations[0]
    self.assertEqual(len(foo_invocation.logs), 1)
    self.assertEqual(len(foo_invocation.child_invocations), 1)
    bar_invocation = foo_invocation.child_invocations[0]
    self.assertEqual(len(bar_invocation.logs), 1)
    self.assertEqual(len(bar_invocation.child_invocations), 0)

    # After the run, the stack is back at the root and the result propagated.
    self.assertIs(session.current_invocation, session.root_invocation)
    self.assertIs(session.final_result, 3)
    html = session.to_html()
    self.assertIn('invocation-final-result', html.content)

  def test_log(self):
    session = action_lib.Session()
    log_methods = (
        session.debug,
        session.info,
        session.warning,
        session.error,
        session.fatal,
    )
    for log in log_methods:
      log('hi', x=1, y=2)

  def test_as_message(self):
    message = action_lib.Session().as_message()
    self.assertIsInstance(message, lf.AIMessage)


if __name__ == '__main__':
  unittest.main()


# Patch metadata of the next file in this patch (continues on the next line):
# diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py
# index c2e8704..0e57cda 100644
# ---
a/langfun/core/concurrent.py +++ b/langfun/core/concurrent.py @@ -423,6 +423,7 @@ def install( status: dict[str, Any] | None = None, ) -> int: """Installs a progress bar and returns a reference id.""" + print('INSTALL', label, total, color, status) with cls._lock: settings = ProgressBar.Settings(label, total, color, status) bar_id = id(settings) diff --git a/langfun/core/eval/__init__.py b/langfun/core/eval/__init__.py index 99956ea..37da5cc 100644 --- a/langfun/core/eval/__init__.py +++ b/langfun/core/eval/__init__.py @@ -39,6 +39,7 @@ from langfun.core.eval.matching import Matching from langfun.core.eval.scoring import Scoring +from langfun.core.eval.action_eval import ActionEval # Experiment patching. from langfun.core.eval.patching import patch_member diff --git a/langfun/core/eval/action_eval.py b/langfun/core/eval/action_eval.py new file mode 100644 index 0000000..447adc0 --- /dev/null +++ b/langfun/core/eval/action_eval.py @@ -0,0 +1,129 @@ +# Copyright 2024 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Evaluation (v1) for Langfun agentic actions."""

import io
import os
from typing import Any

import langfun.core as lf
from langfun.core import agentic as lf_agentic
from langfun.core.eval import matching as matching_lib
import pyglove as pg


@pg.functor()
def _dummy_schema():
  """Returns a placeholder schema for `ActionEval.schema_fn`."""
  return int


class ExampleView(pg.Object):
  """Per-example record that `ActionEval.audit` renders to HTML."""
  id: int
  input: Any
  output: Any
  error: str | None = None


class ActionEval(matching_lib.Matching):
  """Base class for action evaluations.

  The input function should return a list of pg.Dict, with `action` and
  `groundtruth` fields.
  """
  # We override the schema and prompt to dummy values since they are not used.
  schema_fn = _dummy_schema()
  prompt = ''

  def process(self, example: pg.Dict, **kwargs):
    """Runs `example.action` in a fresh session and returns it as a message."""
    action = example.action
    session = lf_agentic.Session()
    action(session=session, lm=self.lm, **kwargs)
    return session.as_message()

  def answer(self, output: Any, example: pg.Dict) -> Any:
    """Returns the output unchanged as the answer to compare."""
    return output

  def groundtruth(self, example: Any) -> Any:
    """Returns the `groundtruth` field of the example."""
    return example.groundtruth

  def audit(
      self,
      example_idx: int,
      example: Any,
      message: lf.Message | None,
      error: Exception | None = None,
      dryrun: bool = False,
  ):
    """Audits the example and saves it as `example_<idx>.html` in `self.dir`.

    Nothing is written for dry runs or when the evaluation has no directory.
    """
    super().audit(example_idx, example, message, error, dryrun)
    if dryrun or not self.dir:
      return

    # `ExampleView.error` is declared as `str | None`, while `error` arrives
    # as an exception object; convert it to text to honor the declared type.
    if error is None or isinstance(error, str):
      error_text = error
    else:
      error_text = f'{type(error).__name__}: {error}'

    # Write each example to HTML.
    def _save_html():
      ExampleView(
          example_idx,
          example,
          None if message is None else message.result,
          error_text
      ).to_html(
          collapse_level=None,
          enable_summary_tooltip=False,
      ).save(
          os.path.join(self.dir, f'example_{example_idx}.html')
      )

    # Write HTML in a separate thread to avoid blocking the main thread.
    lf.concurrent.get_executor(
        'background_eval_io', max_workers=16
    ).submit(_save_html)

  def _render_example_links(
      self,
      s: io.StringIO,
      title: str,
      example_ids: list[int],
      empty_message: str,
  ) -> None:
    """Writes a titled list of links to the per-example HTML pages.

    NOTE(review): the HTML tags of the string literals were missing from the
    copy of this file under review (only the text between tags survived).
    The markup below is a reconstruction -- a heading, one anchor per example
    and a frame showing the first example -- and should be confirmed against
    the upstream file.

    Args:
      s: The buffer to write to.
      title: Heading text of the section.
      example_ids: Indices of the examples to link, in any order.
      empty_message: Text written when `example_ids` is empty.
    """
    s.write(f'<h2> {title} </h2>')
    first_url = None
    for example_idx in sorted(example_ids):
      url = os.path.join(self.dir, f'example_{example_idx}.html')
      if first_url is None:
        first_url = url
      s.write(
          f'<a href="{url}" style="margin-right: 10px" target="example_view">'
          f'{example_idx}</a> '
      )
    if first_url:
      s.write(
          '<iframe style="border:0;width:100%;height:100%" '
          f'name="example_view" src="{first_url}" title="Example View">'
          '</iframe>'
      )
    else:
      s.write(empty_message)

  def _render_mismatches(self, s: io.StringIO) -> None:
    """Writes links to the examples whose answer mismatched the groundtruth."""
    self._render_example_links(
        s,
        'Mismatches (Incorrect)',
        [example_idx for example_idx, *_ in self.mismatches],
        'No mismatches found.',
    )

  def _render_matches(self, s: io.StringIO) -> None:
    """Writes links to the examples whose answer matched the groundtruth."""
    self._render_example_links(
        s,
        'Matches (correct)',
        [example_idx for example_idx, *_ in self.matches],
        'No matches found.',
    )


# Patch metadata and licence header of the next file in this patch:
# diff --git a/langfun/core/eval/action_eval_test.py b/langfun/core/eval/action_eval_test.py
# new file mode 100644
# index 0000000..b51806c
# --- /dev/null
# +++ b/langfun/core/eval/action_eval_test.py
# @@ -0,0 +1,86 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for action evaluation."""

import unittest

from langfun.core import agentic as lf_agentic
from langfun.core import llms as lf_llms
from langfun.core.eval import action_eval
import pyglove as pg


class ActionEvalTest(unittest.TestCase):
  """Runs a two-example `ActionEval` and checks the reported result."""

  def test_basics(self):

    class Foo(lf_agentic.Action):
      """Action that returns its own `x`."""
      x: int

      def call(self, session, **kwargs):
        del session, kwargs
        return self.x

    @pg.functor()
    def foo_inputs():
      return [
          pg.Dict(action=Foo(1), groundtruth=1),
          pg.Dict(action=Foo(2), groundtruth=1),
      ]

    class FooEval(action_eval.ActionEval):
      lm = lf_llms.Echo()
      inputs = foo_inputs()

    s = FooEval()
    result = s.run(summary=False)
    pg.print(result)

    expected_setup = {
        'id': s.id,
        'dir': None,
        'model': 'Echo',
        'prompt_template': '',
        'method': 'query',
        'schema_fn': '_dummy_schema()',
    }
    expected_cache_stats = {
        'use_cache': True,
        'num_queries': 0,
        'num_hits': 0,
        'num_updates': 0,
    }
    expected_metrics = {
        'total': 2,
        'failures': 0,
        'failure_rate': 0.0,
        'oop_failures': 0,
        'oop_failure_rate': 0.0,
        'non_oop_failures': 0,
        'non_oop_failure_rate': 0.0,
        'failure_breakdown': {},
        'num_matches': 0,
        'match_rate': 0.0,
        'num_mismatches': 2,
        'mismatch_rate': 1.0,
    }
    self.assertEqual(
        result,
        {
            'experiment_setup': expected_setup,
            'cache_stats': expected_cache_stats,
            'metrics': expected_metrics,
            'usage': None,
        },
    )


if __name__ == '__main__':
  unittest.main()