Introduce lf.track_queries.
- Revise the input preparation logic of `lf.query`.
- Add `lf.QueryInvocation` to represent `lf.query` invocations.
- Add `lf.track_queries` context manager to track `lf.query` invocations (see the usage sketch below).
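
A minimal usage sketch of the new API (illustrative only; it uses a fake LM from `langfun.core.llms` so it runs without a real model):

```python
import langfun as lf
from langfun.core.llms import fake

# Fake LM that replays canned responses, so this sketch needs no real model.
lm = fake.StaticSequence(['2', 'red'])

# Collect every `lf.query` call made inside the scope.
with lf.track_queries() as queries:
  lf.query('What is 1 + 1?', int, lm=lm)
  lf.query('Name a primary color.', lm=lm)

# Each entry is an `lf.QueryInvocation` recording the rendered input, schema,
# output, language model and few-shot examples of that call.
for invocation in queries:
  print(invocation.schema, invocation.output)
```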

PiperOrigin-RevId: 702220910
daiyip authored and langfun authors committed Dec 3, 2024
1 parent c47d506 commit 5cbdc5a
Showing 4 changed files with 134 additions and 40 deletions.
2 changes: 2 additions & 0 deletions langfun/core/structured/__init__.py
@@ -69,6 +69,8 @@
from langfun.core.structured.prompting import query_prompt
from langfun.core.structured.prompting import query_output
from langfun.core.structured.prompting import query_reward
from langfun.core.structured.prompting import QueryInvocation
from langfun.core.structured.prompting import track_queries

from langfun.core.structured.description import DescribeStructure
from langfun.core.structured.description import describe
160 changes: 121 additions & 39 deletions langfun/core/structured/prompting.py
@@ -13,8 +13,9 @@
# limitations under the License.
"""Symbolic query."""

import contextlib
import functools
from typing import Any, Callable, Type, Union
from typing import Annotated, Any, Callable, Iterator, Type, Union

import langfun.core as lf
from langfun.core.llms import fake
@@ -119,7 +120,7 @@ def query(
skip_lm: bool = False,
**kwargs,
) -> Any:
"""Parse a natural langugage message based on schema.
"""Queries an language model for a (maybe) structured output.
Examples:
@@ -189,55 +190,71 @@ class Flight(pg.Object):
"""
# Internal usage logging.

# Normalize query schema.
# When `lf.query` is used for symbolic completion, schema is automatically
# inferred when it is None.
if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
schema = prompt.__class__

# Create a copy of the prompt if it has a parent object, so all child modality
# objects could be referred by path relative to the prompt.
if isinstance(prompt, lf.Template) and prompt.sym_parent:
prompt = prompt.clone()
# Normalize query input.
if isinstance(prompt, (lf.Message, str)):
# Query with structured output.
prompt_kwargs = kwargs.copy()
prompt_kwargs.pop('template_str', None)
query_input = lf.Template.from_value(prompt, **prompt_kwargs)
elif isinstance(prompt, lf.Template):
# Create a copy of the prompt if it has a parent object, so all child
# modality objects could be referred by path relative to the prompt.
query_input = prompt.clone() if prompt.sym_parent is not None else prompt
else:
query_input = schema_lib.mark_missing(prompt)

def _result(message: lf.Message):
return message.text if schema in (None, str) else message.result

if schema in (None, str):
# Query with natural language output.
output = lf.LangFunc.from_value(prompt, **kwargs)(
output_message = lf.LangFunc.from_value(query_input)(
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
)
if response_postprocess:
processed_text = response_postprocess(output.text)
if processed_text != output.text:
output = lf.AIMessage(processed_text, source=output)
return output if returns_message else output.text

# Query with structured output.
prompt_kwargs = kwargs.copy()

# NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
# QueryStructure template string. Therefore, we pop out the argument for
# prompt rendering.
prompt_kwargs.pop('template_str', None)

if isinstance(prompt, (str, lf.Message, lf.Template)):
prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
processed_text = response_postprocess(output_message.text)
if processed_text != output_message.text:
output_message = lf.AIMessage(processed_text, source=output_message)
else:
prompt = schema_lib.mark_missing(prompt)

output = _query_structure_cls(protocol)(
input=prompt,
schema=schema,
default=default,
examples=examples,
response_postprocess=response_postprocess,
autofix=autofix if protocol == 'python' else 0,
**kwargs,
)(
lm=lm,
autofix_lm=autofix_lm or lm,
cache_seed=cache_seed,
skip_lm=skip_lm,
)
return output if returns_message else output.result
# Query with structured output.
output_message = _query_structure_cls(protocol)(
input=(
query_input.render(lm=lm)
if isinstance(query_input, lf.Template)
else query_input
),
schema=schema,
default=default,
examples=examples,
response_postprocess=response_postprocess,
autofix=autofix if protocol == 'python' else 0,
**kwargs,
)(
lm=lm,
autofix_lm=autofix_lm or lm,
cache_seed=cache_seed,
skip_lm=skip_lm,
)

# Track the query invocations.
trackers = lf.context_value('__query_trackers__', [])
if trackers:
invocation = QueryInvocation(
input=pg.Ref(query_input),
schema=pg.Ref(schema) if schema is not str else None,
output=pg.Ref(_result(output_message)),
lm=pg.Ref(lm),
examples=pg.Ref(examples),
)
for tracker in trackers:
tracker.append(invocation)
return output_message if returns_message else _result(output_message)


def query_prompt(
@@ -320,3 +337,68 @@ def _reward(self, input, expected_output, metadata): # pylint: disable=redefine
args = [self, input, expected_output, metadata]
return cls.__reward__(*args[:num_args])
return _reward


class QueryInvocation(pg.Object):
"""A class to represent the invocation of `lf.query`."""

input: Annotated[
Union[lf.Template, pg.Symbolic],
'Mapping input of `lf.query`.'
]
schema: pg.typing.Annotated[
schema_lib.schema_spec(noneable=True),
'Schema of `lf.query`.'
]
output: Annotated[
Union[lf.Message, pg.Object],
'Mapping output of `lf.query`.'
]
lm: Annotated[
lf.LanguageModel,
'Language model used for `lf.query`.'
]
examples: Annotated[
list[mapping.MappingExample],
'Fewshot exemplars for `lf.query`.'
] = []


@contextlib.contextmanager
def track_queries(
include_child_scopes: bool = True
) -> Iterator[list[QueryInvocation]]:
"""Track all queries made during the context.
Example:
```
with lf.track_queries() as queries:
lf.query('hi', lm=lm)
lf.query('What is this {{image}}?', lm=lm, image=image)
print(queries)
```
Args:
include_child_scopes: If True, the queries made in child scopes will be
included in the returned list. Otherwise, only the queries made in the
current scope will be included.
Yields:
A list of `QueryInvocation` objects representing the queries made during
the context.
"""
trackers = lf.context_value('__query_trackers__', [])
tracker = []

if include_child_scopes:
trackers.append(tracker)
else:
trackers = [tracker]

  with lf.context(__query_trackers__=trackers):
    try:
      yield tracker
    finally:
      if include_child_scopes:
        # Pop this tracker from the shared list so queries made after the
        # scope exits are no longer recorded into it.
        trackers.pop(-1)
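
For illustration (not part of the diff), a sketch of how nested tracking scopes are expected to compose under the implementation above, again using a fake LM:

```python
import langfun as lf
from langfun.core.llms import fake

lm = fake.StaticSequence(['hi', 'there'])

with lf.track_queries() as outer:
  lf.query('hello', lm=lm)
  with lf.track_queries(include_child_scopes=False) as inner:
    lf.query('world', lm=lm)

# The inner scope replaced the tracker list, so its query is visible only to
# `inner`; with the default include_child_scopes=True it would also appear in
# `outer`.
assert len(outer) == 1 and len(inner) == 1
```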
10 changes: 9 additions & 1 deletion langfun/core/structured/prompting_test.py
@@ -89,7 +89,7 @@ def test_call(self):
)
self.assertEqual(
prompting.query(
lf.Template('what is {{x}} + {{y}}'), int, x=1, y=0, lm=lm.clone()
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
),
1,
)
@@ -945,5 +945,13 @@ def test_query(self):
)


class TrackQueriesTest(unittest.TestCase):

def test_basic(self):
pass

def test_concurrent_map(self):
pass

if __name__ == '__main__':
unittest.main()
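
The new test bodies are collapsed in the view above; below is a minimal sketch of what a tracking test could look like (an assumption for illustration, not the test added by this commit):

```python
import unittest

from langfun.core.llms import fake
from langfun.core.structured import prompting


class TrackQueriesSketch(unittest.TestCase):
  """Illustrative sketch only; not the test shipped in this commit."""

  def test_basic(self):
    lm = fake.StaticSequence(['red', 'blue'])
    with prompting.track_queries() as queries:
      prompting.query('the color of the sky', lm=lm)
      prompting.query('the color of the sea', lm=lm)
    self.assertEqual(len(queries), 2)
    self.assertTrue(
        all(isinstance(q, prompting.QueryInvocation) for q in queries)
    )


if __name__ == '__main__':
  unittest.main()
```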
2 changes: 2 additions & 0 deletions langfun/core/template.py
@@ -516,6 +516,8 @@ def from_value(
if isinstance(value, cls):
return value.clone(override=kwargs) if kwargs else value # pylint: disable=no-value-for-parameter
if isinstance(value, str):
template_vars = cls.resolve_vars(value)
kwargs = {k: v for k, v in kwargs.items() if k in template_vars}
return cls(template_str=value, **kwargs)
if isinstance(value, message_lib.Message):
kwargs.update(value.metadata)
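A sketch of the intended effect of the filtering above (assuming `resolve_vars` returns the names of the variables referenced in the template string):

```python
import langfun as lf

# Keyword arguments that are not referenced by the template string are now
# dropped instead of being bound onto the template.
t = lf.Template.from_value('what is {{x}} + {{y}}', x=1, y=0, unit='apples')
print(t.render().text)  # -> 'what is 1 + 0'; `unit` is not a template variable.
```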
