Skip to content

Commit

Permalink
Introduce lf.track_queries.
Browse files Browse the repository at this point in the history
- Revise the input preparation logic of `lf.query`.
- Add `lf.QueryInvocation` to represent `lf.query` invocations.
- Add `lf.track_queries` context manager to track `lf.query` invocations.

PiperOrigin-RevId: 702220910
  • Loading branch information
daiyip authored and langfun authors committed Dec 3, 2024
1 parent c47d506 commit 6a9d006
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 48 deletions.
2 changes: 2 additions & 0 deletions langfun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
score = structured.score
generate_class = structured.generate_class

track_queries = structured.track_queries

# Helper functions for input/output transformations based on
# `lf.query` (e.g. jax-on-beam could use these for batch processing)
query_prompt = structured.query_prompt
Expand Down
2 changes: 2 additions & 0 deletions langfun/core/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
from langfun.core.structured.prompting import query_prompt
from langfun.core.structured.prompting import query_output
from langfun.core.structured.prompting import query_reward
from langfun.core.structured.prompting import QueryInvocation
from langfun.core.structured.prompting import track_queries

from langfun.core.structured.description import DescribeStructure
from langfun.core.structured.description import describe
Expand Down
195 changes: 148 additions & 47 deletions langfun/core/structured/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
"""Symbolic query."""

import contextlib
import functools
from typing import Any, Callable, Type, Union
from typing import Annotated, Any, Callable, Iterator, Type, Union

import langfun.core as lf
from langfun.core.llms import fake
Expand Down Expand Up @@ -102,7 +103,7 @@ def _query_structure_cls(


def query(
prompt: Union[str, pg.Symbolic],
prompt: Union[str, lf.Template, Any],
schema: Union[
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
] = None,
Expand All @@ -119,7 +120,7 @@ def query(
skip_lm: bool = False,
**kwargs,
) -> Any:
"""Parse a natural langugage message based on schema.
"""Queries an language model for a (maybe) structured output.
Examples:
Expand Down Expand Up @@ -189,59 +190,93 @@ class Flight(pg.Object):
"""
# Internal usage logging.

# Normalize query schema.
# When `lf.query` is used for symbolic completion, schema is automatically
# inferred when it is None.
if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
schema = prompt.__class__

# Create a copy of the prompt if it has a parent object, so all child modality
# objects could be referred by path relative to the prompt.
if isinstance(prompt, lf.Template) and prompt.sym_parent:
prompt = prompt.clone()

if schema in (None, str):
# Query with natural language output.
output = lf.LangFunc.from_value(prompt, **kwargs)(
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
# Normalize query input.
if isinstance(prompt, (lf.Message, str)):
# Convert a string or message prompt into a template.
prompt_kwargs = kwargs.copy()
prompt_kwargs.pop('template_str', None)
query_input = lf.Template.from_value(prompt, **prompt_kwargs)
elif isinstance(prompt, lf.Template):
# Create a copy of the prompt if it has a parent object, so all child
# modality objects could be referred by path relative to the prompt.
query_input = prompt.clone() if prompt.sym_parent is not None else prompt

# Attach template metadata from kwargs. This is used to pass through fields
# from kwargs to the rendered message.
template_metadata = {
k: v for k, v in kwargs.items() if k.startswith('metadata_')
}
query_input.rebind(
template_metadata, skip_notification=True, raise_on_no_change=False
)
if response_postprocess:
processed_text = response_postprocess(output.text)
if processed_text != output.text:
output = lf.AIMessage(processed_text, source=output)
return output if returns_message else output.text

# Query with structured output.
prompt_kwargs = kwargs.copy()

# NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
# QueryStructure template string. Therefore, we pop out the argument for
# prompt rendering.
prompt_kwargs.pop('template_str', None)

if isinstance(prompt, (str, lf.Message, lf.Template)):
prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
elif pg.MISSING_VALUE == prompt:
query_input = lf.UserMessage('')
else:
prompt = schema_lib.mark_missing(prompt)

output = _query_structure_cls(protocol)(
input=prompt,
schema=schema,
default=default,
examples=examples,
response_postprocess=response_postprocess,
autofix=autofix if protocol == 'python' else 0,
**kwargs,
)(
lm=lm,
autofix_lm=autofix_lm or lm,
cache_seed=cache_seed,
skip_lm=skip_lm,
)
return output if returns_message else output.result
query_input = schema_lib.mark_missing(prompt)

with lf.track_usages() as usage_summary:
if schema in (None, str):
# Query with natural language output.
output_message = lf.LangFunc.from_value(query_input, **kwargs)(
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
)
if response_postprocess:
processed_text = response_postprocess(output_message.text)
if processed_text != output_message.text:
output_message = lf.AIMessage(processed_text, source=output_message)
else:
# Query with structured output.
output_message = _query_structure_cls(protocol)(
input=(
query_input.render(lm=lm)
if isinstance(query_input, lf.Template)
else query_input
),
schema=schema,
default=default,
examples=examples,
response_postprocess=response_postprocess,
autofix=autofix if protocol == 'python' else 0,
**kwargs,
)(
lm=lm,
autofix_lm=autofix_lm or lm,
cache_seed=cache_seed,
skip_lm=skip_lm,
)

def _result(message: lf.Message):
return message.text if schema in (None, str) else message.result

# Track the query invocations.
if pg.MISSING_VALUE != prompt and not skip_lm:
trackers = lf.context_value('__query_trackers__', [])
if trackers:
invocation = QueryInvocation(
input=pg.Ref(query_input),
schema=(
schema_lib.Schema.from_value(schema)
if schema not in (None, str) else None
),
output=pg.Ref(_result(output_message)),
lm=pg.Ref(lm),
examples=pg.Ref(examples) if examples else [],
usage_summary=usage_summary,
)
for i, (tracker, include_child_scopes) in enumerate(trackers):
if i == 0 or include_child_scopes:
tracker.append(invocation)
return output_message if returns_message else _result(output_message)


def query_prompt(
prompt: Union[str, pg.Symbolic],
prompt: Union[str, lf.Template, Any],
schema: Union[
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
] = None,
Expand All @@ -264,7 +299,7 @@ def query_output(
kwargs.pop('prompt', None)
kwargs.pop('lm', None)
return query(
'Unused prompt', schema, lm=fake.StaticResponse(response), **kwargs
pg.MISSING_VALUE, schema, lm=fake.StaticResponse(response), **kwargs
)


Expand Down Expand Up @@ -320,3 +355,69 @@ def _reward(self, input, expected_output, metadata): # pylint: disable=redefine
args = [self, input, expected_output, metadata]
return cls.__reward__(*args[:num_args])
return _reward


class QueryInvocation(pg.Object):
  """Record of a single `lf.query` call.

  Each field captures one aspect of the call: what was asked, the schema that
  was requested, what came back, and what it cost.
  """

  # The prompt (or symbolic value) that was sent to the query.
  input: Annotated[
      Union[lf.Template, pg.Symbolic], 'Mapping input of `lf.query`.',
  ]

  # The requested output schema; the spec is declared noneable.
  schema: pg.typing.Annotated[
      schema_lib.schema_spec(noneable=True), 'Schema of `lf.query`.',
  ]

  # The value produced by the query.
  output: Annotated[Any, 'Mapping output of `lf.query`.']

  # The language model that served the query.
  lm: Annotated[lf.LanguageModel, 'Language model used for `lf.query`.']

  # Few-shot examples supplied to the query.
  examples: Annotated[
      list[mapping.MappingExample], 'Fewshot exemplars for `lf.query`.',
  ]

  # Usage summary recorded for the query.
  usage_summary: Annotated[lf.UsageSummary, 'Usage summary for `lf.query`.']


@contextlib.contextmanager
def track_queries(
    include_child_scopes: bool = True
) -> Iterator[list[QueryInvocation]]:
  """Tracks all queries made during the context.

  Example:

    ```
    with lf.track_queries() as queries:
      lf.query('hi', lm=lm)
      lf.query('What is this {{image}}?', lm=lm, image=image)

    print(queries)
    ```

  Args:
    include_child_scopes: If True, the queries made in child scopes will be
      included in the returned list. Otherwise, only the queries made in the
      current scope will be included.

  Yields:
    A list of `QueryInvocation` objects representing the queries made during
    the context.
  """
  # Trackers already registered by enclosing `track_queries` scopes, if any.
  outer_trackers = lf.context_value('__query_trackers__', [])
  tracker: list[QueryInvocation] = []

  # Build a new list with this scope's tracker first instead of mutating the
  # enclosing scope's list, so the outer registration is left untouched.
  # There is nothing to clean up after the `yield`, so no try/finally is
  # needed: an exception raised in the caller's `with` body propagates as is.
  with lf.context(
      __query_trackers__=[(tracker, include_child_scopes)] + outer_trackers
  ):
    yield tracker
85 changes: 84 additions & 1 deletion langfun/core/structured/prompting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_call(self):
)
self.assertEqual(
prompting.query(
lf.Template('what is {{x}} + {{y}}'), int, x=1, y=0, lm=lm.clone()
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
),
1,
)
Expand Down Expand Up @@ -365,6 +365,23 @@ class Answer:
"""),
)

def test_query_prompt_with_metadata(self):
  """Checks that a `metadata_x` kwarg shows up as `x` in prompt metadata."""
  # Without a schema.
  prompt_without_schema = prompting.query_prompt(
      'what is this?', metadata_x=1
  )
  self.assertIn('x', prompt_without_schema.metadata)

  # With `int` as the schema.
  prompt_with_schema = prompting.query_prompt(
      'what is this?', int, metadata_x=1
  )
  self.assertIn('x', prompt_with_schema.metadata)

def test_query_prompt_with_unrooted_template(self):
output = prompting.query_prompt(
pg.Dict(
Expand Down Expand Up @@ -945,5 +962,71 @@ def test_query(self):
)


class TrackQueriesTest(unittest.TestCase):
  """Tests for the `prompting.track_queries` context manager."""

  def test_include_child_scopes(self):
    """An outer scope also records queries issued inside a nested scope."""
    lm = fake.StaticSequence(['bar', 'Activity(description="hi")'])
    with prompting.track_queries() as queries:
      prompting.query('foo', lm=lm)
      with prompting.track_queries() as child_queries:
        prompting.query('give me an activity', Activity, lm=lm)

    self.assertEqual(len(queries), 2)
    text_query, activity_query = queries

    self.assertTrue(pg.eq(text_query.input, lf.Template('foo')))
    self.assertIsNone(text_query.schema)
    self.assertEqual(text_query.output, 'bar')
    self.assertIs(text_query.lm, lm)

    self.assertTrue(
        pg.eq(activity_query.input, lf.Template('give me an activity'))
    )
    self.assertEqual(activity_query.schema.spec.cls, Activity)
    self.assertTrue(pg.eq(activity_query.output, Activity(description='hi')))
    self.assertIs(activity_query.lm, lm)
    self.assertGreater(text_query.usage_summary.total.total_tokens, 0)
    self.assertGreater(activity_query.usage_summary.total.total_tokens, 0)

    # The nested scope sees only its own query, and it is the very same
    # invocation object that the outer scope recorded.
    self.assertEqual(len(child_queries), 1)
    self.assertIs(child_queries[0], activity_query)

  def test_exclude_child_scopes(self):
    """With `include_child_scopes=False` each scope records only its own."""
    lm = fake.StaticSequence(['bar', 'Activity(description="hi")'])
    with prompting.track_queries(include_child_scopes=False) as queries:
      prompting.query('foo', lm=lm)
      with prompting.track_queries(
          include_child_scopes=False
      ) as child_queries:
        prompting.query('give me an activity', Activity, lm=lm)

    self.assertEqual(len(queries), 1)
    outer_query = queries[0]
    self.assertTrue(pg.eq(outer_query.input, lf.Template('foo')))
    self.assertIsNone(outer_query.schema)
    self.assertEqual(outer_query.output, 'bar')
    self.assertIs(outer_query.lm, lm)

    self.assertEqual(len(child_queries), 1)
    inner_query = child_queries[0]
    self.assertTrue(
        pg.eq(inner_query.input, lf.Template('give me an activity'))
    )
    self.assertEqual(inner_query.schema.spec.cls, Activity)
    self.assertTrue(pg.eq(inner_query.output, Activity(description='hi')))
    self.assertIs(inner_query.lm, lm)

  def test_concurrent_map(self):
    """Queries issued through `lf.concurrent_map` are tracked as well."""
    lm = fake.StaticSequence(['foo', 'bar'])

    def make_query(prompt):
      prompting.query(prompt, lm=lm)

    with prompting.track_queries() as queries:
      # Materialize the results so that every mapped call has run.
      list(lf.concurrent_map(make_query, ['a', 'b']))
    self.assertEqual(len(queries), 2)


if __name__ == '__main__':
unittest.main()

0 comments on commit 6a9d006

Please sign in to comment.