Skip to content

Commit

Permalink
Use session to save metadata.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704862545
  • Loading branch information
Langfun Authors committed Dec 11, 2024
1 parent e668f7f commit 89db54c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
20 changes: 9 additions & 11 deletions langfun/core/agentic/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,24 @@ def __call__(

with session.track_action(self):
result = self.call(session=session, **kwargs)
metadata = dict()
if (isinstance(result, tuple)
and len(result) == 2 and isinstance(result[1], dict)):
result, metadata = result

# For the top-level action, we store the session in the metadata.
if new_session:
self._session = session
self._result, self._result_metadata = result, metadata
self._result = result
self._result_metadata = session.current_action.result_metadata
return self._result

@abc.abstractmethod
def call(
self,
session: 'Session',
**kwargs
) -> Union[Any, tuple[Any, dict[str, Any]]]:
def call(self, session: 'Session', **kwargs) -> Any:
"""Calls the action.
Args:
session: The session to use for the action.
**kwargs: Additional keyword arguments to pass to the action.
Returns:
The result of the action or a tuple of (result, result_metadata).
The result of the action.
"""


Expand Down Expand Up @@ -726,6 +719,11 @@ def current_action(self) -> ActionInvocation:
"""Returns the current invocation."""
return self._current_action

def add_metadata(self, **kwargs: Any) -> None:
"""Adds metadata to the current invocation."""
with pg.notify_on_change(False):
self._current_action.result_metadata.update(kwargs)

def phase(self, name: str) -> ContextManager[ExecutionTrace]:
"""Context manager for starting a new execution phase."""
return self.current_action.phase(name)
Expand Down
6 changes: 4 additions & 2 deletions langfun/core/agentic/action_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def call(self, session, *, lm, **kwargs):
test.assertIs(session.current_action.action, self)
session.info('Begin Bar')
session.query('bar', lm=lm)
return 2, dict(note='bar')
session.add_metadata(note='bar')
return 2

class Foo(action_lib.Action):
x: int
Expand All @@ -45,7 +46,8 @@ def call(self, session, *, lm, **kwargs):
session.query('foo', lm=lm)
with session.track_queries():
self.make_additional_query(lm)
return self.x + Bar()(session, lm=lm), dict(note='foo')
session.add_metadata(note='foo')
return self.x + Bar()(session, lm=lm)

def make_additional_query(self, lm):
lf_structured.query('additional query', lm=lm)
Expand Down

0 comments on commit 89db54c

Please sign in to comment.