Skip to content

Commit

Permalink
Introduce lf.track_queries.
Browse files Browse the repository at this point in the history
- Revise the input preparation logic of `lf.query`.
- Add `lf.QueryInvocation` to represent `lf.query` invocations.
- Add `lf.track_queries` context manager to track `lf.query` invocations.

PiperOrigin-RevId: 702220910
  • Loading branch information
daiyip authored and langfun authors committed Dec 3, 2024
1 parent c47d506 commit f06cff4
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 42 deletions.
2 changes: 2 additions & 0 deletions langfun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
score = structured.score
generate_class = structured.generate_class

track_queries = structured.track_queries

# Helper functions for input/output transformations based on
# `lf.query` (e.g. jax-on-beam could use these for batch processing)
query_prompt = structured.query_prompt
Expand Down
2 changes: 2 additions & 0 deletions langfun/core/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
from langfun.core.structured.prompting import query_prompt
from langfun.core.structured.prompting import query_output
from langfun.core.structured.prompting import query_reward
from langfun.core.structured.prompting import QueryInvocation
from langfun.core.structured.prompting import track_queries

from langfun.core.structured.description import DescribeStructure
from langfun.core.structured.description import describe
Expand Down
175 changes: 134 additions & 41 deletions langfun/core/structured/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
"""Symbolic query."""

import contextlib
import functools
from typing import Any, Callable, Type, Union
from typing import Annotated, Any, Callable, Iterator, Type, Union

import langfun.core as lf
from langfun.core.llms import fake
Expand Down Expand Up @@ -102,7 +103,7 @@ def _query_structure_cls(


def query(
prompt: Union[str, pg.Symbolic],
prompt: Union[str, lf.Template, Any],
schema: Union[
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
] = None,
Expand All @@ -119,7 +120,7 @@ def query(
skip_lm: bool = False,
**kwargs,
) -> Any:
"""Parse a natural langugage message based on schema.
"""Queries an language model for a (maybe) structured output.
Examples:
Expand Down Expand Up @@ -189,55 +190,85 @@ class Flight(pg.Object):
"""
# Internal usage logging.

# Normalize query schema.
# When `lf.query` is used for symbolic completion, schema is automatically
# inferred when it is None.
if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
schema = prompt.__class__

# Create a copy of the prompt if it has a parent object, so all child modality
# objects could be referred by path relative to the prompt.
if isinstance(prompt, lf.Template) and prompt.sym_parent:
prompt = prompt.clone()
# Normalize query input.
if isinstance(prompt, (lf.Message, str)):
# Query with structured output.
prompt_kwargs = kwargs.copy()
prompt_kwargs.pop('template_str', None)
query_input = lf.Template.from_value(prompt, **prompt_kwargs)
elif isinstance(prompt, lf.Template):
# Create a copy of the prompt if it has a parent object, so all child
# modality objects could be referred by path relative to the prompt.
query_input = prompt.clone() if prompt.sym_parent is not None else prompt

# Attach template metadata from kwargs. This is used to pass through fields
# from kwargs to the rendered message.
template_metadata = {k: v for k, v in kwargs if k.startswith('metadata_')}
query_input.rebind(
template_metadata, skip_notification=True, raise_on_no_change=False
)
elif pg.MISSING_VALUE == prompt:
query_input = lf.UserMessage('Unused prompt.')
else:
query_input = schema_lib.mark_missing(prompt)

if schema in (None, str):
# Query with natural language output.
output = lf.LangFunc.from_value(prompt, **kwargs)(
output_message = lf.LangFunc.from_value(query_input)(
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
)
if response_postprocess:
processed_text = response_postprocess(output.text)
if processed_text != output.text:
output = lf.AIMessage(processed_text, source=output)
return output if returns_message else output.text

# Query with structured output.
prompt_kwargs = kwargs.copy()

# NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
# QueryStructure template string. Therefore, we pop out the argument for
# prompt rendering.
prompt_kwargs.pop('template_str', None)

if isinstance(prompt, (str, lf.Message, lf.Template)):
prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
processed_text = response_postprocess(output_message.text)
if processed_text != output_message.text:
output_message = lf.AIMessage(processed_text, source=output_message)
else:
prompt = schema_lib.mark_missing(prompt)

output = _query_structure_cls(protocol)(
input=prompt,
schema=schema,
default=default,
examples=examples,
response_postprocess=response_postprocess,
autofix=autofix if protocol == 'python' else 0,
**kwargs,
)(
lm=lm,
autofix_lm=autofix_lm or lm,
cache_seed=cache_seed,
skip_lm=skip_lm,
)
return output if returns_message else output.result
# Query with structured output.
output_message = _query_structure_cls(protocol)(
input=(
query_input.render(lm=lm)
if isinstance(query_input, lf.Template)
else query_input
),
schema=schema,
default=default,
examples=examples,
response_postprocess=response_postprocess,
autofix=autofix if protocol == 'python' else 0,
**kwargs,
)(
lm=lm,
autofix_lm=autofix_lm or lm,
cache_seed=cache_seed,
skip_lm=skip_lm,
)

def _result(message: lf.Message):
return message.text if schema in (None, str) else message.result

# Track the query invocations.
if pg.MISSING_VALUE != prompt and not skip_lm:
trackers = lf.context_value('__query_trackers__', [])
if trackers:
invocation = QueryInvocation(
input=pg.Ref(query_input),
schema=(
schema_lib.Schema.from_value(schema)
if schema not in (None, str) else None
),
output=pg.Ref(_result(output_message)),
lm=pg.Ref(lm),
examples=pg.Ref(examples) if examples else [],
)
for i, (tracker, include_child_scopes) in enumerate(trackers):
if i == 0 or include_child_scopes:
tracker.append(invocation)
return output_message if returns_message else _result(output_message)


def query_prompt(
Expand All @@ -264,7 +295,7 @@ def query_output(
kwargs.pop('prompt', None)
kwargs.pop('lm', None)
return query(
'Unused prompt', schema, lm=fake.StaticResponse(response), **kwargs
pg.MISSING_VALUE, schema, lm=fake.StaticResponse(response), **kwargs
)


Expand Down Expand Up @@ -320,3 +351,65 @@ def _reward(self, input, expected_output, metadata): # pylint: disable=redefine
args = [self, input, expected_output, metadata]
return cls.__reward__(*args[:num_args])
return _reward


class QueryInvocation(pg.Object):
"""A class to represent the invocation of `lf.query`."""

input: Annotated[
Union[lf.Template, pg.Symbolic],
'Mapping input of `lf.query`.'
]
schema: pg.typing.Annotated[
schema_lib.schema_spec(noneable=True),
'Schema of `lf.query`.'
]
output: Annotated[
Any,
'Mapping output of `lf.query`.'
]
lm: Annotated[
lf.LanguageModel,
'Language model used for `lf.query`.'
]
examples: Annotated[
list[mapping.MappingExample],
'Fewshot exemplars for `lf.query`.'
] = []


@contextlib.contextmanager
def track_queries(
include_child_scopes: bool = True
) -> Iterator[list[QueryInvocation]]:
"""Track all queries made during the context.
Example:
```
with lf.track_queries() as queries:
lf.query('hi', lm=lm)
lf.query('What is this {{image}}?', lm=lm, image=image)
print(queries)
```
Args:
include_child_scopes: If True, the queries made in child scopes will be
included in the returned list. Otherwise, only the queries made in the
current scope will be included.
Yields:
A list of `QueryInvocation` objects representing the queries made during
the context.
"""
trackers = lf.context_value('__query_trackers__', [])
tracker = []

with lf.context(
__query_trackers__=[(tracker, include_child_scopes)] + trackers
):
try:
yield tracker
finally:
pass
66 changes: 65 additions & 1 deletion langfun/core/structured/prompting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_call(self):
)
self.assertEqual(
prompting.query(
lf.Template('what is {{x}} + {{y}}'), int, x=1, y=0, lm=lm.clone()
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
),
1,
)
Expand Down Expand Up @@ -945,5 +945,69 @@ def test_query(self):
)


class TrackQueriesTest(unittest.TestCase):

def test_include_child_scopes(self):
lm = fake.StaticSequence([
'bar',
'Activity(description="hi")',
])
with prompting.track_queries() as queries:
prompting.query('foo', lm=lm)
with prompting.track_queries() as child_queries:
prompting.query('give me an activity', Activity, lm=lm)

self.assertEqual(len(queries), 2)
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
self.assertIsNone(queries[0].schema)
self.assertEqual(queries[0].output, 'bar')
self.assertIs(queries[0].lm, lm)

self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
self.assertEqual(queries[1].schema.spec.cls, Activity)
self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
self.assertIs(queries[1].lm, lm)

self.assertEqual(len(child_queries), 1)
self.assertIs(child_queries[0], queries[1])

def test_exclude_child_scopes(self):
lm = fake.StaticSequence([
'bar',
'Activity(description="hi")',
])
with prompting.track_queries(include_child_scopes=False) as queries:
prompting.query('foo', lm=lm)
with prompting.track_queries(include_child_scopes=False) as child_queries:
prompting.query('give me an activity', Activity, lm=lm)

self.assertEqual(len(queries), 1)
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
self.assertIsNone(queries[0].schema)
self.assertEqual(queries[0].output, 'bar')
self.assertIs(queries[0].lm, lm)

self.assertEqual(len(child_queries), 1)
self.assertTrue(
pg.eq(child_queries[0].input, lf.Template('give me an activity'))
)
self.assertEqual(child_queries[0].schema.spec.cls, Activity)
self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi')))
self.assertIs(child_queries[0].lm, lm)

def test_concurrent_map(self):

def make_query(prompt):
_ = prompting.query(prompt, lm=lm)

lm = fake.StaticSequence([
'foo',
'bar',
])
with prompting.track_queries() as queries:
list(lf.concurrent_map(make_query, ['a', 'b']))
self.assertEqual(len(queries), 2)


if __name__ == '__main__':
unittest.main()

0 comments on commit f06cff4

Please sign in to comment.