diff --git a/langfun/core/agentic/action.py b/langfun/core/agentic/action.py index e0c42a6..1b546c7 100644 --- a/langfun/core/agentic/action.py +++ b/langfun/core/agentic/action.py @@ -14,8 +14,10 @@ """Base classes for agentic actions.""" import abc -from typing import Annotated, Any, Optional, Union +import contextlib +from typing import Annotated, Any, Iterable, Iterator, Optional, Type, Union import langfun.core as lf +from langfun.core import structured as lf_structured import pyglove as pg @@ -35,12 +37,9 @@ def __call__( self, session: Optional['Session'] = None, **kwargs) -> Any: """Executes the action.""" session = session or Session() - try: - session.begin(self) + with session.track(self): self._result = self.call(session=session, **kwargs) return self._result - finally: - session.end(self) @abc.abstractmethod def call(self, session: 'Session', **kwargs) -> Any: @@ -50,9 +49,20 @@ def call(self, session: 'Session', **kwargs) -> Any: class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension): """A class for capturing the invocation of an action.""" action: Action - result: Any = None + + result: Annotated[ + Any, + 'The result of the action.' + ] = None + execution: Annotated[ - list[Union['ActionInvocation', lf.logging.LogEntry]], + list[ + Union[ + lf_structured.QueryInvocation, + 'ActionInvocation', + lf.logging.LogEntry + ] + ], 'Execution execution.' ] = [] @@ -69,6 +79,18 @@ def child_invocations(self) -> list['ActionInvocation']: """Returns child action invocations.""" return [v for v in self.execution if isinstance(v, ActionInvocation)] + def queries( + self, + include_children: bool = False + ) -> Iterable[lf_structured.QueryInvocation]: + """Iterates over queries from the current invocation.""" + for v in self.execution: + if isinstance(v, lf_structured.QueryInvocation): + yield v + elif isinstance(v, ActionInvocation): + if include_children: + yield from v.queries(include_children=True) + def _html_tree_view_summary( self, *, view: pg.views.html.HtmlTreeView, **kwargs ): @@ -190,29 +212,57 @@ def current_invocation(self) -> ActionInvocation: assert self._invocation_stack return self._invocation_stack[-1] - def begin(self, action: Action): - """Signal the beginning of the execution of an action.""" + @contextlib.contextmanager + def track(self, action: Action) -> Iterator[ActionInvocation]: + """Track the execution of an action.""" new_invocation = ActionInvocation(pg.maybe_ref(action)) with pg.notify_on_change(False): self.current_invocation.execution.append(new_invocation) self._invocation_stack.append(new_invocation) - def end(self, action: Action): - """Signal the end of the execution of an action.""" - assert self._invocation_stack - invocation = self._invocation_stack.pop(-1) - invocation.rebind( - result=action.result, skip_notification=True, raise_on_no_change=False - ) - assert invocation.action is action, (invocation.action, action) - assert self._invocation_stack, self._invocation_stack - - if len(self._invocation_stack) == 1: - self.root_invocation.rebind( - result=invocation.result, - skip_notification=True, - raise_on_no_change=False + try: + yield new_invocation + finally: + assert self._invocation_stack + invocation = self._invocation_stack.pop(-1) + invocation.rebind( + result=action.result, skip_notification=True, raise_on_no_change=False ) + assert invocation.action is action, (invocation.action, action) + assert self._invocation_stack, self._invocation_stack + + if len(self._invocation_stack) == 1: + self.root_invocation.rebind( + result=invocation.result, + skip_notification=True, + raise_on_no_change=False + ) + + def query( + self, + prompt: Union[str, lf.Template, Any], + schema: Union[ + lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None + ] = None, + default: Any = lf.RAISE_IF_HAS_ERROR, + *, + lm: lf.LanguageModel | None = None, + examples: list[lf_structured.MappingExample] | None = None, + **kwargs + ) -> Any: + """Calls `lf.query` and associates it with the current invocation.""" + with lf_structured.track_queries() as queries: + output = lf_structured.query( + prompt, + schema=schema, + default=default, + lm=lm, + examples=examples, + **kwargs + ) + with pg.notify_on_change(False): + self.current_invocation.execution.extend(queries) + return output def _log(self, level: lf.logging.LogLevel, message: str, **kwargs): with pg.notify_on_change(False): diff --git a/langfun/core/agentic/action_test.py b/langfun/core/agentic/action_test.py index 25234e2..639ce92 100644 --- a/langfun/core/agentic/action_test.py +++ b/langfun/core/agentic/action_test.py @@ -17,6 +17,7 @@ import langfun.core as lf from langfun.core.agentic import action as action_lib +from langfun.core.llms import fake class SessionTest(unittest.TestCase): @@ -26,25 +27,35 @@ def test_basics(self): class Bar(action_lib.Action): - def call(self, session, **kwargs): + def call(self, session, *, lm, **kwargs): test.assertIs(session.current_invocation.action, self) session.info('Begin Bar') + session.query('bar', lm=lm) return 2 class Foo(action_lib.Action): x: int - def call(self, session, **kwargs): + def call(self, session, *, lm, **kwargs): test.assertIs(session.current_invocation.action, self) session.info('Begin Foo', x=1) - return self.x + Bar()(session) + session.query('foo', lm=lm) + return self.x + Bar()(session, lm=lm) + lm = fake.StaticResponse('lm response') session = action_lib.Session() root = session.root_invocation self.assertIsInstance(root.action, action_lib.RootAction) self.assertIs(session.current_invocation, session.root_invocation) - self.assertEqual(Foo(1)(session), 3) + self.assertEqual(Foo(1)(session, lm=lm), 3) self.assertEqual(len(session.root_invocation.child_invocations), 1) + self.assertEqual(len(list(session.root_invocation.queries())), 0) + self.assertEqual( + len(list(session.root_invocation.queries(include_children=True))), 2 + ) + self.assertEqual( + len(list(session.root_invocation.child_invocations[0].queries())), 1 + ) self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1) self.assertEqual( len(session.root_invocation.child_invocations[0].child_invocations), @@ -55,6 +66,11 @@ def call(self, session, **kwargs): .child_invocations[0].child_invocations[0].logs), 1 ) + self.assertEqual( + len(list(session.root_invocation + .child_invocations[0].child_invocations[0].queries())), + 1 + ) self.assertEqual( len(session.root_invocation .child_invocations[0].child_invocations[0].child_invocations),