diff --git a/langfun/__init__.py b/langfun/__init__.py index bad173a..be03f86 100644 --- a/langfun/__init__.py +++ b/langfun/__init__.py @@ -37,6 +37,9 @@ track_queries = structured.track_queries +# Helper function for map-reduce style querying. +query_and_reduce = structured.query_and_reduce + # Helper functions for input/output transformations based on # `lf.query` (e.g. jax-on-beam could use these for batch processing) query_prompt = structured.query_prompt diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index a32c207..5b964eb 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -56,6 +56,8 @@ from langfun.core.structured.querying import track_queries from langfun.core.structured.querying import QueryInvocation from langfun.core.structured.querying import query +from langfun.core.structured.querying import query_and_reduce + from langfun.core.structured.querying import query_prompt from langfun.core.structured.querying import query_output from langfun.core.structured.querying import query_reward diff --git a/langfun/core/structured/querying.py b/langfun/core/structured/querying.py index b85a632..8c10457 100644 --- a/langfun/core/structured/querying.py +++ b/langfun/core/structured/querying.py @@ -105,12 +105,11 @@ def _query_structure_cls( def query( prompt: Union[str, lf.Template, Any], - schema: Union[ - schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None - ] = None, + schema: schema_lib.SchemaType | None = None, default: Any = lf.RAISE_IF_HAS_ERROR, *, - lm: lf.LanguageModel | None = None, + lm: lf.LanguageModel | list[lf.LanguageModel] | None = None, + num_samples: int | list[int] = 1, examples: list[mapping.MappingExample] | None = None, cache_seed: int | None = 0, response_postprocess: Callable[[str], str] | None = None, @@ -121,35 +120,114 @@ def query( skip_lm: bool = False, **kwargs, ) -> Any: - """Queries an language model for a (maybe) structured output. 
+ """Queries language model(s) for (maybe) structured outputs. + + `lf.query` is the single most important API of Langfun for querying LLMs. It + implements Object Oriented Programming (OOP) and supports the following + features: + + - **Input**: natural language (str), structured input (pg.Object instances) + and their mix (lf.Template instances), including modality objects. + - **Output**: structured output (when schema is not None) or natural language + (str). + - **Few-shot examples**: Structured few-shot prompting with `examples`. + - **Fan out multiple queries** in parallel to multiple language models, each + with multiple samples. Examples: + Case 1: Regular natural language-based LLM query: + + ``` + lf.query('1 + 1 = ?', lm=lf.llms.Gpt4Turbo()) + + # Answer: + # '2' + ``` + + Case 2: Query with structured output. + + ``` + lf.query('1 + 1 = ?', int, lm=lf.llms.Gpt4Turbo()) + + # Answer: + # 2 + ``` + + Case 3: Query with structured input. + + ``` + class Sum(pg.Object): + a: int + b: int + + lf.query(Sum(1, 1), int, lm=lf.llms.Gpt4Turbo()) + + # Answer: + # 2 ``` - class FlightDuration: - hours: int - minutes: int - - class Flight(pg.Object): - airline: str - flight_number: str - departure_airport_code: str - arrival_airport_code: str - departure_time: str - arrival_time: str - duration: FlightDuration - stops: int - price: float - - prompt = ''' - Information about flight UA2631. - ''' - - r = lf.query(prompt, Flight) - assert isinstance(r, Flight) - assert r.airline == 'United Airlines' - assert r.departure_airport_code == 'SFO' - assert r.duration.hour = 7 + + Case 4: Query with input of mixed modalities. + + ``` + class Animal(pg.Object): + pass + + class Dog(Animal): + pass + + class Entity(pg.Object): + name: str + + lf.query( + 'What is in this {{image}} and {{objects}}?' 
+      list[Entity],
+      lm=lf.llms.Gpt4Turbo(),
+      image=lf.Image(path='/path/to/a/airplane.png'),
+      objects=[Dog()],
+    )
+
+    # Answer:
+    #   [Entity(name='airplane'), Entity(name='dog')]
+    ```
+
+  Case 5: Query with structured few-shot examples.
+    ```
+    lf.query(
+      'What is in this {{image}} and {{objects}}?',
+      list[Entity],
+      lm=lf.llms.Gpt4Turbo(),
+      image=lf.Image(path='/path/to/a/dinosaur.png'),
+      objects=[Dog()],
+      examples=[
+        lf.MappingExample(
+          input=lf.Template(
+            'What is the object near the house in this {{image}}?',
+            image=lf.Image(path='/path/to/image.png'),
+          ),
+          schema=Entity,
+          output=Entity('cat'),
+        ),
+      ],
+    )
+
+    # Answer:
+    #   [Entity(name='dinosaur'), Entity(name='dog')]
+    ```
+
+  Case 6: Multiple queries to multiple models.
+    ```
+    lf.query(
+      '1 + 1 = ?',
+      int,
+      lm=[
+        lf.llms.Gpt4Turbo(),
+        lf.llms.Gemini1_5Pro(),
+      ],
+      num_samples=[1, 2],
+    )
+    # Answer:
+    # [2, 2, 2]
     ```
 
   Args:
@@ -160,7 +238,12 @@ class Flight(pg.Object):
     default: The default value if parsing failed. If not specified, error will
       be raised.
     lm: The language model to use. If not specified, the language model from
-      `lf.context` context manager will be used.
+      `lf.context` context manager will be used. If a list of language models
+      is provided, multiple queries will be issued in parallel to these models,
+      with each model returning samples specified by `num_samples`.
+    num_samples: The number of samples to return for each model. If a list of
+      integers is provided, its length must match the length of `lm`,
+      indicating the number of samples for each model.
     examples: An optional list of fewshot examples for helping parsing. If None,
       the default one-shot example will be added.
     cache_seed: Seed for computing cache key. The cache key is determined by a
@@ -187,10 +270,61 @@ class Flight(pg.Object):
     - mapping_template: Change the template for each mapping examle.
 
   Returns:
-    The result based on the schema.
+    A single or a list of outputs (when `lm` is a list or `num_samples` is
+    greater than 1). Each output will be a `lf.Message` if `returns_message` is
+    set to True, or an instance specified by the `schema` otherwise. If no
+    `schema` is provided, the output will be a `str` in natural language.
   """
   # Internal usage logging.
 
+  # Multiple queries will be issued when `lm` is a list or `num_samples` is
+  # greater than 1.
+  if isinstance(lm, list) or num_samples != 1:
+    def _single_query(inputs):
+      lm, example_i = inputs
+      return query(
+          prompt,
+          schema,
+          default=default,
+          lm=lm,
+          examples=examples,
+          # Usually num_samples should not be large, so we multiply the user
+          # provided cache seed by 100 to avoid collision.
+          cache_seed=(
+              None if cache_seed is None else cache_seed * 100 + example_i
+          ),
+          response_postprocess=response_postprocess,
+          autofix=autofix,
+          autofix_lm=autofix_lm,
+          protocol=protocol,
+          returns_message=returns_message,
+          skip_lm=skip_lm,
+          **kwargs,
+      )
+    lm_list = lm if isinstance(lm, list) else [lm]
+    num_samples_list = (
+        num_samples if isinstance(num_samples, list)
+        else [num_samples] * len(lm_list)
+    )
+    assert len(lm_list) == len(num_samples_list), (
+        'Expect the length of `num_samples` to be the same as the '
+        f'length of `lm`. Got {num_samples} and {lm_list}.'
+    )
+    query_inputs = []
+    total_queries = 0
+    for lm, num_samples in zip(lm_list, num_samples_list):
+      query_inputs.extend([(lm, i) for i in range(num_samples)])
+      total_queries += num_samples
+
+    samples = []
+    for _, output, error in lf.concurrent_map(
+        _single_query, query_inputs, max_workers=max(64, total_queries),
+        ordered=True,
+    ):
+      if error is None:
+        samples.append(output)
+    return samples
+
   # Normalize query schema.
   # When `lf.query` is used for symbolic completion, schema is automatically
   # inferred when it is None.
@@ -280,11 +414,52 @@ def _result(message: lf.Message):
   return output_message if returns_message else _result(output_message)
 
 
+#
+# Helper function for map-reduce style querying.
+#
+
+
+def query_and_reduce(
+    prompt: Union[str, lf.Template, Any],
+    schema: schema_lib.SchemaType | None = None,
+    *,
+    reduce: Callable[[list[Any]], Any],
+    lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
+    num_samples: int | list[int] = 1,
+    **kwargs,
+) -> Any:
+  """Issues multiple `lf.query` calls in parallel and reduces the outputs.
+
+  Args:
+    prompt: A str (may contain {{}} as template) as natural language input, or a
+      `pg.Symbolic` object as structured input as prompt to LLM.
+    schema: A type annotation as the schema for the output object. If None
+      (default), the response will be a str in natural language.
+    reduce: A function to reduce the outputs of multiple `lf.query` calls. It
+      takes a list of outputs and returns the final object.
+    lm: The language model to use. If not specified, the language model from
+      `lf.context` context manager will be used.
+    num_samples: The number of samples to obtain from each language model being
+      requested. If a list is provided, it should have the same length as `lm`.
+    **kwargs: Additional arguments to pass to `lf.query`.
+
+  Returns:
+    The reduced output from multiple `lf.query` calls.
+  """
+  results = query(prompt, schema, lm=lm, num_samples=num_samples, **kwargs)
+  if isinstance(results, list):
+    results = reduce(results)
+  return results
+
+
+#
+# Functions for decomposing `lf.query` into pre-llm and post-llm operations.
+# + + def query_prompt( prompt: Union[str, lf.Template, Any], - schema: Union[ - schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None - ] = None, + schema: schema_lib.SchemaType | None = None, **kwargs, ) -> lf.Message: """Returns the final prompt sent to LLM for `lf.query`.""" @@ -295,9 +470,7 @@ def query_prompt( def query_output( response: Union[str, lf.Message], - schema: Union[ - schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None - ], + schema: schema_lib.SchemaType | None = None, **kwargs, ) -> Any: """Returns the final output of `lf.query` from a provided LLM response.""" @@ -308,6 +481,11 @@ def query_output( ) +# +# Functions for computing reward of an LLM response based on a mapping example. +# + + def query_reward( mapping_example: Union[str, mapping.MappingExample], response: Union[str, lf.Message], @@ -362,6 +540,11 @@ def _reward(self, input, expected_output, metadata): # pylint: disable=redefine return _reward +# +# Functions for tracking `lf.query` invocations. 
+# + + class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension): """A class to represent the invocation of `lf.query`.""" diff --git a/langfun/core/structured/querying_test.py b/langfun/core/structured/querying_test.py index ecc7678..8d32859 100644 --- a/langfun/core/structured/querying_test.py +++ b/langfun/core/structured/querying_test.py @@ -327,6 +327,69 @@ def test_structure_with_modality_and_examples_to_structure_render(self): expected_modalities=3, ) + def test_multiple_queries(self): + self.assertEqual( + querying.query( + 'Compute 1 + 2', + int, + lm=[ + fake.StaticResponse('1'), + fake.StaticResponse('2'), + ], + num_samples=[1, 2], + ), + [1, 2, 2] + ) + self.assertEqual( + querying.query( + 'Compute 1 + 2', + int, + lm=[ + fake.StaticResponse('1'), + fake.StaticResponse('2'), + ], + num_samples=2, + ), + [1, 1, 2, 2] + ) + self.assertEqual( + querying.query( + 'Compute 1 + 2', + int, + lm=[ + fake.StaticResponse('1'), + fake.StaticResponse('abc'), + ], + num_samples=[1, 2], + ), + [1] + ) + self.assertEqual( + querying.query( + 'Compute 1 + 2', + int, + default=0, + lm=[ + fake.StaticResponse('1'), + fake.StaticResponse('abc'), + ], + num_samples=[1, 2], + ), + [1, 0, 0] + ) + results = querying.query( + 'Compute 1 + 2', + int, + default=0, + lm=[ + fake.StaticResponse('1'), + fake.StaticResponse('abc'), + ], + returns_message=True, + ) + self.assertEqual([r.text for r in results], ['1', 'abc']) + self.assertEqual([r.result for r in results], [1, 0]) + def test_bad_protocol(self): with self.assertRaisesRegex(ValueError, 'Unknown protocol'): querying.query('what is 1 + 1', int, protocol='text') @@ -393,6 +456,30 @@ def test_query_prompt_with_unrooted_template(self): ) self.assertIsNotNone(output.get_modality('image')) + def test_query_and_reduce(self): + self.assertEqual( + querying.query_and_reduce( + 'Compute 1 + 1', + int, + reduce=sum, + lm=[ + fake.StaticResponse('1'), + fake.StaticResponse('2'), + ], + num_samples=[1, 2], + ), + 5 + ) + 
self.assertEqual( + querying.query_and_reduce( + 'Compute 1 + 1', + int, + reduce=sum, + lm=fake.StaticResponse('2'), + ), + 2 + ) + def test_query_output(self): self.assertEqual( querying.query_output( diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index c1677d9..6c7bc3c 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -227,6 +227,9 @@ def _html_tree_view_tooltip( ) +SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]] + + def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]: """Returns a list of top level value specs from a symbolic value.""" top_level_object_specs = []