Skip to content

Commit

Permalink
Enable query tracking for Langfun agents.
Browse files Browse the repository at this point in the history
By introducing `lf.agentic.Session.query`, we are able to track calls to `lf.query` and associate them with the right action and call venue. It also enables a possibility of returning `lf.query` output from stored trajectory.

After this change, agents developers are encouraged to use `session.query` to replace `lf.query` within an Action for automatic tracking/saving capability.

PiperOrigin-RevId: 702456741
  • Loading branch information
daiyip authored and langfun authors committed Dec 3, 2024
1 parent 0634b85 commit 1220ac6
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 28 deletions.
98 changes: 74 additions & 24 deletions langfun/core/agentic/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
"""Base classes for agentic actions."""

import abc
from typing import Annotated, Any, Optional, Union
import contextlib
from typing import Annotated, Any, Iterable, Iterator, Optional, Type, Union
import langfun.core as lf
from langfun.core import structured as lf_structured
import pyglove as pg


Expand All @@ -35,12 +37,9 @@ def __call__(
self, session: Optional['Session'] = None, **kwargs) -> Any:
"""Executes the action."""
session = session or Session()
try:
session.begin(self)
with session.track(self):
self._result = self.call(session=session, **kwargs)
return self._result
finally:
session.end(self)

@abc.abstractmethod
def call(self, session: 'Session', **kwargs) -> Any:
Expand All @@ -50,9 +49,20 @@ def call(self, session: 'Session', **kwargs) -> Any:
class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
"""A class for capturing the invocation of an action."""
action: Action
result: Any = None

result: Annotated[
Any,
'The result of the action.'
] = None

execution: Annotated[
list[Union['ActionInvocation', lf.logging.LogEntry]],
list[
Union[
lf_structured.QueryInvocation,
'ActionInvocation',
lf.logging.LogEntry
]
],
'Execution execution.'
] = []

Expand All @@ -69,6 +79,18 @@ def child_invocations(self) -> list['ActionInvocation']:
"""Returns child action invocations."""
return [v for v in self.execution if isinstance(v, ActionInvocation)]

def queries(
self,
include_children: bool = False
) -> Iterable[lf_structured.QueryInvocation]:
"""Iterates over queries from the current invocation."""
for v in self.execution:
if isinstance(v, lf_structured.QueryInvocation):
yield v
elif isinstance(v, ActionInvocation):
if include_children:
yield from v.queries(include_children=True)

def _html_tree_view_summary(
self, *, view: pg.views.html.HtmlTreeView, **kwargs
):
Expand Down Expand Up @@ -190,29 +212,57 @@ def current_invocation(self) -> ActionInvocation:
assert self._invocation_stack
return self._invocation_stack[-1]

def begin(self, action: Action):
"""Signal the beginning of the execution of an action."""
@contextlib.contextmanager
def track(self, action: Action) -> Iterator[ActionInvocation]:
"""Track the execution of an action."""
new_invocation = ActionInvocation(pg.maybe_ref(action))
with pg.notify_on_change(False):
self.current_invocation.execution.append(new_invocation)
self._invocation_stack.append(new_invocation)

def end(self, action: Action):
"""Signal the end of the execution of an action."""
assert self._invocation_stack
invocation = self._invocation_stack.pop(-1)
invocation.rebind(
result=action.result, skip_notification=True, raise_on_no_change=False
)
assert invocation.action is action, (invocation.action, action)
assert self._invocation_stack, self._invocation_stack

if len(self._invocation_stack) == 1:
self.root_invocation.rebind(
result=invocation.result,
skip_notification=True,
raise_on_no_change=False
try:
yield new_invocation
finally:
assert self._invocation_stack
invocation = self._invocation_stack.pop(-1)
invocation.rebind(
result=action.result, skip_notification=True, raise_on_no_change=False
)
assert invocation.action is action, (invocation.action, action)
assert self._invocation_stack, self._invocation_stack

if len(self._invocation_stack) == 1:
self.root_invocation.rebind(
result=invocation.result,
skip_notification=True,
raise_on_no_change=False
)

def query(
self,
prompt: Union[str, lf.Template, Any],
schema: Union[
lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
] = None,
default: Any = lf.RAISE_IF_HAS_ERROR,
*,
lm: lf.LanguageModel | None = None,
examples: list[lf_structured.MappingExample] | None = None,
**kwargs
) -> Any:
"""Calls `lf.query` and associates it with the current invocation."""
with lf_structured.track_queries() as queries:
output = lf_structured.query(
prompt,
schema=schema,
default=default,
lm=lm,
examples=examples,
**kwargs
)
with pg.notify_on_change(False):
self.current_invocation.execution.extend(queries)
return output

def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
with pg.notify_on_change(False):
Expand Down
24 changes: 20 additions & 4 deletions langfun/core/agentic/action_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import langfun.core as lf
from langfun.core.agentic import action as action_lib
from langfun.core.llms import fake


class SessionTest(unittest.TestCase):
Expand All @@ -26,25 +27,35 @@ def test_basics(self):

class Bar(action_lib.Action):

def call(self, session, **kwargs):
def call(self, session, *, lm, **kwargs):
test.assertIs(session.current_invocation.action, self)
session.info('Begin Bar')
session.query('bar', lm=lm)
return 2

class Foo(action_lib.Action):
x: int

def call(self, session, **kwargs):
def call(self, session, *, lm, **kwargs):
test.assertIs(session.current_invocation.action, self)
session.info('Begin Foo', x=1)
return self.x + Bar()(session)
session.query('foo', lm=lm)
return self.x + Bar()(session, lm=lm)

lm = fake.StaticResponse('lm response')
session = action_lib.Session()
root = session.root_invocation
self.assertIsInstance(root.action, action_lib.RootAction)
self.assertIs(session.current_invocation, session.root_invocation)
self.assertEqual(Foo(1)(session), 3)
self.assertEqual(Foo(1)(session, lm=lm), 3)
self.assertEqual(len(session.root_invocation.child_invocations), 1)
self.assertEqual(len(list(session.root_invocation.queries())), 0)
self.assertEqual(
len(list(session.root_invocation.queries(include_children=True))), 2
)
self.assertEqual(
len(list(session.root_invocation.child_invocations[0].queries())), 1
)
self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
self.assertEqual(
len(session.root_invocation.child_invocations[0].child_invocations),
Expand All @@ -55,6 +66,11 @@ def call(self, session, **kwargs):
.child_invocations[0].child_invocations[0].logs),
1
)
self.assertEqual(
len(list(session.root_invocation
.child_invocations[0].child_invocations[0].queries())),
1
)
self.assertEqual(
len(session.root_invocation
.child_invocations[0].child_invocations[0].child_invocations),
Expand Down

0 comments on commit 1220ac6

Please sign in to comment.