diff --git a/langfun/__init__.py b/langfun/__init__.py index 685ab3f..bad173a 100644 --- a/langfun/__init__.py +++ b/langfun/__init__.py @@ -35,6 +35,8 @@ score = structured.score generate_class = structured.generate_class +track_queries = structured.track_queries + # Helper functions for input/output transformations based on # `lf.query` (e.g. jax-on-beam could use these for batch processing) query_prompt = structured.query_prompt diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index 931a6e5..685b067 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -69,6 +69,8 @@ from langfun.core.structured.prompting import query_prompt from langfun.core.structured.prompting import query_output from langfun.core.structured.prompting import query_reward +from langfun.core.structured.prompting import QueryInvocation +from langfun.core.structured.prompting import track_queries from langfun.core.structured.description import DescribeStructure from langfun.core.structured.description import describe diff --git a/langfun/core/structured/prompting.py b/langfun/core/structured/prompting.py index 41d6f05..9860cb7 100644 --- a/langfun/core/structured/prompting.py +++ b/langfun/core/structured/prompting.py @@ -13,8 +13,9 @@ # limitations under the License. """Symbolic query.""" +import contextlib import functools -from typing import Any, Callable, Type, Union +from typing import Annotated, Any, Callable, Iterator, Type, Union import langfun.core as lf from langfun.core.llms import fake @@ -102,7 +103,7 @@ def _query_structure_cls( def query( - prompt: Union[str, pg.Symbolic], + prompt: Union[str, lf.Template, Any], schema: Union[ schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None ] = None, @@ -119,7 +120,7 @@ def query( skip_lm: bool = False, **kwargs, ) -> Any: - """Parse a natural langugage message based on schema. + """Queries an language model for a (maybe) structured output. Examples: @@ -189,55 +190,88 @@ class Flight(pg.Object): """ # Internal usage logging. + # Normalize query schema. # When `lf.query` is used for symbolic completion, schema is automatically # inferred when it is None. if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None: schema = prompt.__class__ - # Create a copy of the prompt if it has a parent object, so all child modality - # objects could be referred by path relative to the prompt. - if isinstance(prompt, lf.Template) and prompt.sym_parent: - prompt = prompt.clone() - - if schema in (None, str): - # Query with natural language output. - output = lf.LangFunc.from_value(prompt, **kwargs)( - lm=lm, cache_seed=cache_seed, skip_lm=skip_lm + # Normalize query input. + if isinstance(prompt, (lf.Message, str)): + # Query with structured output. + prompt_kwargs = kwargs.copy() + prompt_kwargs.pop('template_str', None) + query_input = lf.Template.from_value(prompt, **prompt_kwargs) + elif isinstance(prompt, lf.Template): + # Create a copy of the prompt if it has a parent object, so all child + # modality objects could be referred by path relative to the prompt. + query_input = prompt.clone() if prompt.sym_parent is not None else prompt + + # Attach template metadata from kwargs. This is used to pass through fields + # from kwargs to the rendered message. + template_metadata = { + k: v for k, v in kwargs.items() if k.startswith('metadata_') + } + query_input.rebind( + template_metadata, skip_notification=True, raise_on_no_change=False ) - if response_postprocess: - processed_text = response_postprocess(output.text) - if processed_text != output.text: - output = lf.AIMessage(processed_text, source=output) - return output if returns_message else output.text - - # Query with structured output. - prompt_kwargs = kwargs.copy() - - # NOTE(daiyip): when `template_str` is passed in, it's intended to modify the - # QueryStructure template string. Therefore, we pop out the argument for - # prompt rendering. - prompt_kwargs.pop('template_str', None) - - if isinstance(prompt, (str, lf.Message, lf.Template)): - prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm) + elif pg.MISSING_VALUE == prompt: + query_input = lf.UserMessage('Unused prompt.') else: - prompt = schema_lib.mark_missing(prompt) - - output = _query_structure_cls(protocol)( - input=prompt, - schema=schema, - default=default, - examples=examples, - response_postprocess=response_postprocess, - autofix=autofix if protocol == 'python' else 0, - **kwargs, - )( - lm=lm, - autofix_lm=autofix_lm or lm, - cache_seed=cache_seed, - skip_lm=skip_lm, - ) - return output if returns_message else output.result + query_input = schema_lib.mark_missing(prompt) + + with lf.track_usages(): + if schema in (None, str): + # Query with natural language output. + output_message = lf.LangFunc.from_value(query_input)( + lm=lm, cache_seed=cache_seed, skip_lm=skip_lm + ) + if response_postprocess: + processed_text = response_postprocess(output_message.text) + if processed_text != output_message.text: + output_message = lf.AIMessage(processed_text, source=output_message) + else: + # Query with structured output. + output_message = _query_structure_cls(protocol)( + input=( + query_input.render(lm=lm) + if isinstance(query_input, lf.Template) + else query_input + ), + schema=schema, + default=default, + examples=examples, + response_postprocess=response_postprocess, + autofix=autofix if protocol == 'python' else 0, + **kwargs, + )( + lm=lm, + autofix_lm=autofix_lm or lm, + cache_seed=cache_seed, + skip_lm=skip_lm, + ) + + def _result(message: lf.Message): + return message.text if schema in (None, str) else message.result + + # Track the query invocations. + if pg.MISSING_VALUE != prompt and not skip_lm: + trackers = lf.context_value('__query_trackers__', []) + if trackers: + invocation = QueryInvocation( + input=pg.Ref(query_input), + schema=( + schema_lib.Schema.from_value(schema) + if schema not in (None, str) else None + ), + output=pg.Ref(_result(output_message)), + lm=pg.Ref(lm), + examples=pg.Ref(examples) if examples else [], + ) + for i, (tracker, include_child_scopes) in enumerate(trackers): + if i == 0 or include_child_scopes: + tracker.append(invocation) + return output_message if returns_message else _result(output_message) def query_prompt( @@ -264,7 +298,7 @@ def query_output( kwargs.pop('prompt', None) kwargs.pop('lm', None) return query( - 'Unused prompt', schema, lm=fake.StaticResponse(response), **kwargs + pg.MISSING_VALUE, schema, lm=fake.StaticResponse(response), **kwargs ) @@ -320,3 +354,65 @@ def _reward(self, input, expected_output, metadata): # pylint: disable=redefine args = [self, input, expected_output, metadata] return cls.__reward__(*args[:num_args]) return _reward + + +class QueryInvocation(pg.Object): + """A class to represent the invocation of `lf.query`.""" + + input: Annotated[ + Union[lf.Template, pg.Symbolic], + 'Mapping input of `lf.query`.' + ] + schema: pg.typing.Annotated[ + schema_lib.schema_spec(noneable=True), + 'Schema of `lf.query`.' + ] + output: Annotated[ + Any, + 'Mapping output of `lf.query`.' + ] + lm: Annotated[ + lf.LanguageModel, + 'Language model used for `lf.query`.' + ] + examples: Annotated[ + list[mapping.MappingExample], + 'Fewshot exemplars for `lf.query`.' + ] = [] + + +@contextlib.contextmanager +def track_queries( + include_child_scopes: bool = True +) -> Iterator[list[QueryInvocation]]: + """Track all queries made during the context. + + Example: + + ``` + with lf.track_queries() as queries: + lf.query('hi', lm=lm) + lf.query('What is this {{image}}?', lm=lm, image=image) + + print(queries) + ``` + + Args: + include_child_scopes: If True, the queries made in child scopes will be + included in the returned list. Otherwise, only the queries made in the + current scope will be included. + + Yields: + A list of `QueryInvocation` objects representing the queries made during + the context. + """ + trackers = lf.context_value('__query_trackers__', []) + tracker = [] + + with lf.context( + __query_trackers__=[(tracker, include_child_scopes)] + trackers + ): + try: + yield tracker + finally: + pass diff --git a/langfun/core/structured/prompting_test.py b/langfun/core/structured/prompting_test.py index 6d7eae8..1d99119 100644 --- a/langfun/core/structured/prompting_test.py +++ b/langfun/core/structured/prompting_test.py @@ -89,7 +89,7 @@ def test_call(self): ) self.assertEqual( prompting.query( - lf.Template('what is {{x}} + {{y}}'), int, x=1, y=0, lm=lm.clone() + lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone() ), 1, ) @@ -945,5 +945,69 @@ def test_query(self): ) +class TrackQueriesTest(unittest.TestCase): + + def test_include_child_scopes(self): + lm = fake.StaticSequence([ + 'bar', + 'Activity(description="hi")', + ]) + with prompting.track_queries() as queries: + prompting.query('foo', lm=lm) + with prompting.track_queries() as child_queries: + prompting.query('give me an activity', Activity, lm=lm) + + self.assertEqual(len(queries), 2) + self.assertTrue(pg.eq(queries[0].input, lf.Template('foo'))) + self.assertIsNone(queries[0].schema) + self.assertEqual(queries[0].output, 'bar') + self.assertIs(queries[0].lm, lm) + + self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity'))) + self.assertEqual(queries[1].schema.spec.cls, Activity) + self.assertTrue(pg.eq(queries[1].output, Activity(description='hi'))) + self.assertIs(queries[1].lm, lm) + + self.assertEqual(len(child_queries), 1) + self.assertIs(child_queries[0], queries[1]) + + def test_exclude_child_scopes(self): + lm = fake.StaticSequence([ + 'bar', + 'Activity(description="hi")', + ]) + with prompting.track_queries(include_child_scopes=False) as queries: + prompting.query('foo', lm=lm) + with prompting.track_queries(include_child_scopes=False) as child_queries: + prompting.query('give me an activity', Activity, lm=lm) + + self.assertEqual(len(queries), 1) + self.assertTrue(pg.eq(queries[0].input, lf.Template('foo'))) + self.assertIsNone(queries[0].schema) + self.assertEqual(queries[0].output, 'bar') + self.assertIs(queries[0].lm, lm) + + self.assertEqual(len(child_queries), 1) + self.assertTrue( + pg.eq(child_queries[0].input, lf.Template('give me an activity')) + ) + self.assertEqual(child_queries[0].schema.spec.cls, Activity) + self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi'))) + self.assertIs(child_queries[0].lm, lm) + + def test_concurrent_map(self): + + def make_query(prompt): + _ = prompting.query(prompt, lm=lm) + + lm = fake.StaticSequence([ + 'foo', + 'bar', + ]) + with prompting.track_queries() as queries: + list(lf.concurrent_map(make_query, ['a', 'b'])) + self.assertEqual(len(queries), 2) + + if __name__ == '__main__': unittest.main()