
Commit

Multiple sampling support through `lf.query` and add `lf.query_and_reduce`.

With this CL, `lf.query` now supports issuing parallel queries to multiple LLMs, with multiple samples per model, by accepting a list for `lm` and a new `num_samples` argument.
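
For example, a minimal sketch based on the new signature (the model classes are taken from the example below; the exact behavior noted in comments follows the implementation in this diff):

```python
# With a list of models, `lf.query` fans out the queries in parallel and
# returns a list of outputs instead of a single value.
answers = lf.query(
  'compute 256 * 345',
  int,
  lm=[lf.llms.Gpt4(), lf.llms.Gemini2()],
  num_samples=5,  # 5 samples per model, so up to 10 answers in total.
)
# Queries that fail are dropped from the returned list.
```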

Additionally, we introduce `lf.query_and_reduce` to perform map-reduce style operations with `lf.query`. The motivation for introducing `lf.query_and_reduce` is to maintain a 1:1 input/output mapping at the interface level, while allowing users to supply a `reduce` function that regularizes multiple LLM outputs into a single value. `lf.query_and_reduce` is a drop-in replacement for `lf.query`, with compatible arguments and default behavior.

Example usage: sample Gpt4 and Gemini2 five times each and obtain the final output through majority voting.

```python
lf.query_and_reduce(
  'compute 256 * 345',
  int,
  reduce=lambda answers: lf.query('Find majority from {{answers}}', int, lm=lf.llms.Gpt4(), answers=answers),
  lm=[lf.llms.Gpt4(), lf.llms.Gemini2()],
  num_samples=5,
  ...
)
```
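
Since `reduce` receives the plain list of sampled outputs (its signature is `Callable[[list[Any]], Any]`, as shown in the diff below), the reduction can also be performed locally without an extra LLM call. A minimal sketch, not part of this CL:

```python
from collections import Counter

import langfun as lf

def majority(answers: list[int]) -> int:
  # Return the most frequent answer among all samples.
  return Counter(answers).most_common(1)[0][0]

result = lf.query_and_reduce(
  'compute 256 * 345',
  int,
  reduce=majority,
  lm=[lf.llms.Gpt4(), lf.llms.Gemini2()],
  num_samples=5,
)
```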

PiperOrigin-RevId: 701557382
daiyip authored and langfun authors committed Dec 17, 2024
1 parent 371a8ef commit 44ea310
Showing 5 changed files with 315 additions and 37 deletions.
3 changes: 3 additions & 0 deletions langfun/__init__.py
@@ -37,6 +37,9 @@

track_queries = structured.track_queries

# Helper function for map-reduce style querying.
query_and_reduce = structured.query_and_reduce

# Helper functions for input/output transformations based on
# `lf.query` (e.g. jax-on-beam could use these for batch processing)
query_prompt = structured.query_prompt
2 changes: 2 additions & 0 deletions langfun/core/structured/__init__.py
@@ -56,6 +56,8 @@
from langfun.core.structured.querying import track_queries
from langfun.core.structured.querying import QueryInvocation
from langfun.core.structured.querying import query
from langfun.core.structured.querying import query_and_reduce

from langfun.core.structured.querying import query_prompt
from langfun.core.structured.querying import query_output
from langfun.core.structured.querying import query_reward
257 changes: 220 additions & 37 deletions langfun/core/structured/querying.py
@@ -105,12 +105,11 @@ def _query_structure_cls(

def query(
prompt: Union[str, lf.Template, Any],
schema: Union[
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
] = None,
schema: schema_lib.SchemaType | None = None,
default: Any = lf.RAISE_IF_HAS_ERROR,
*,
lm: lf.LanguageModel | None = None,
lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
num_samples: int | list[int] = 1,
examples: list[mapping.MappingExample] | None = None,
cache_seed: int | None = 0,
response_postprocess: Callable[[str], str] | None = None,
@@ -121,35 +120,114 @@ def query(
skip_lm: bool = False,
**kwargs,
) -> Any:
"""Queries an language model for a (maybe) structured output.
"""Queries language model(s) for (maybe) structured outputs.
`lf.query` is the single most important API of Langfun for querying LLMs. It
applies an Object-Oriented Programming (OOP) approach to prompting and supports
the following features:
- **Input**: natural language (str), structured input (pg.Object instances)
and their mix (lf.Template instances), including modality objects.
- **Output**: structured output (when schema is not None) or natural language
(str).
- **Few-shot examples**: Structured few-shot prompting with `examples`.
- **Fan out multiple queries** in parallel to multiple language models, each
with multiple samples.
Examples:
Case 1: Regular natural language-based LLM query:
```
lf.query('1 + 1 = ?', lm=lf.llms.Gpt4Turbo())
# Answer:
# '2'
```
Case 2: Query with structured output.
```
lf.query('1 + 1 = ?', int, lm=lf.llms.Gpt4Turbo())
# Answer:
# 2
```
Case 3: Query with structured input.
```
class Sum(pg.Object):
  a: int
  b: int

lf.query(Sum(1, 1), int, lm=lf.llms.Gpt4Turbo())
# Answer:
# 2
```
```
class FlightDuration(pg.Object):
  hours: int
  minutes: int

class Flight(pg.Object):
  airline: str
  flight_number: str
  departure_airport_code: str
  arrival_airport_code: str
  departure_time: str
  arrival_time: str
  duration: FlightDuration
  stops: int
  price: float

prompt = '''
Information about flight UA2631.
'''

r = lf.query(prompt, Flight)
assert isinstance(r, Flight)
assert r.airline == 'United Airlines'
assert r.departure_airport_code == 'SFO'
assert r.duration.hours == 7
```
Case 4: Query with input of mixed modalities.
```
class Animal(pg.Object):
  pass

class Dog(Animal):
  pass

class Entity(pg.Object):
  name: str
lf.query(
  'What is in this {{image}} and {{objects}}?',
  list[Entity],
  lm=lf.llms.Gpt4Turbo(),
  image=lf.Image(path='/path/to/a/airplane.png'),
  objects=[Dog()],
)
# Answer:
# [Entity(name='airplane'), Entity(name='dog')]
```
Case 5: Query with structured few-shot examples.
```
lf.query(
  'What is in this {{image}} and {{objects}}?',
  list[Entity],
  lm=lf.llms.Gpt4Turbo(),
  image=lf.Image(path='/path/to/a/dinosaur.png'),
  objects=[Dog()],
examples=[
lf.MappingExample(
input=lf.Template(
'What is the object near the house in this {{image}}?',
image=lf.Image(path='/path/to/image.png'),
),
schema=Entity,
output=Entity('cat'),
),
],
)
# Answer:
# [Entity(name='dinosaur'), Entity(name='dog')]
```
Case 6: Multiple queries to multiple models.
```
lf.query(
'1 + 1 = ?',
int,
lm=[
lf.llms.Gpt4Turbo(),
lf.llms.Gemini1_5Pro(),
],
num_samples=[1, 2],
)
# Answer:
# [2, 2, 2]
```
Args:
@@ -160,7 +238,12 @@ class Flight(pg.Object):
default: The default value to use if parsing fails. If not specified, an error
will be raised.
lm: The language model to use. If not specified, the language model from
`lf.context` context manager will be used.
`lf.context` context manager will be used. If a list of language models
is provided, multiple queries will be issued in parallel to these models,
with each model returning the number of samples specified by `num_samples`.
num_samples: The number of samples to return from each model. If a list of
integers is provided, its length must match the length of `lm`,
indicating the number of samples for each model.
examples: An optional list of few-shot examples to help with parsing. If None,
the default one-shot example will be added.
cache_seed: Seed for computing cache key. The cache key is determined by a
@@ -187,10 +270,61 @@
- mapping_template: Change the template for each mapping example.
Returns:
The result based on the schema.
A single output, or a list of outputs when `lm` is a list or `num_samples` is
greater than 1. Each output will be an `lf.Message` if `returns_message` is
set to True, or an instance of the type specified by `schema` otherwise. If no
`schema` is provided, the output will be a `str` in natural language.
"""
# Internal usage logging.

# Multiple queries will be issued when `lm` is a list or `num_samples` is
# greater than 1.
if isinstance(lm, list) or num_samples != 1:
def _single_query(inputs):
lm, example_i = inputs
return query(
prompt,
schema,
default=default,
lm=lm,
examples=examples,
# Usually `num_samples` should not be large, so we multiply the user-
# provided cache seed by 100 to avoid collisions.
cache_seed=(
None if cache_seed is None else cache_seed * 100 + example_i
),
response_postprocess=response_postprocess,
autofix=autofix,
autofix_lm=autofix_lm,
protocol=protocol,
returns_message=returns_message,
skip_lm=skip_lm,
**kwargs,
)
lm_list = lm if isinstance(lm, list) else [lm]
num_samples_list = (
num_samples if isinstance(num_samples, list)
else [num_samples] * len(lm_list)
)
assert len(lm_list) == len(num_samples_list), (
'Expect the length of `num_samples` to be the same as the '
f'length of `lm`. Got {num_samples} and {lm_list}.'
)
query_inputs = []
total_queries = 0
for lm, num_samples in zip(lm_list, num_samples_list):
query_inputs.extend([(lm, i) for i in range(num_samples)])
total_queries += num_samples

samples = []
for _, output, error in lf.concurrent_map(
_single_query, query_inputs, max_workers=max(64, total_queries),
ordered=True,
):
if error is None:
samples.append(output)
return samples

# Normalize query schema.
# When `lf.query` is used for symbolic completion, schema is automatically
# inferred when it is None.
@@ -280,11 +414,52 @@ def _result(message: lf.Message):
return output_message if returns_message else _result(output_message)


#
# Helper function for map-reduce style querying.
#


def query_and_reduce(
prompt: Union[str, lf.Template, Any],
schema: schema_lib.SchemaType | None = None,
*,
reduce: Callable[[list[Any]], Any],
lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
num_samples: int | list[int] = 1,
**kwargs,
) -> Any:
"""Issues multiple `lf.query` calls in parallel and reduce the outputs.
Args:
prompt: A str (which may contain `{{}}` template placeholders) as natural
language input, or a `pg.Symbolic` object as structured input, serving as the prompt to the LLM.
schema: A type annotation as the schema for the output object. If None
(default), the response will be a str in natural language.
reduce: A function to reduce the outputs of multiple `lf.query` calls. It
takes a list of outputs and returns the final object.
lm: The language model(s) to use. If not specified, the language model from
`lf.context` context manager will be used.
num_samples: The number of samples to obtain from each language model being
requested. If a list is provided, it should have the same length as `lm`.
**kwargs: Additional arguments to pass to `lf.query`.
Returns:
The reduced output from multiple `lf.query` calls.
"""
results = query(prompt, schema, lm=lm, num_samples=num_samples, **kwargs)
if isinstance(results, list):
results = reduce(results)
return results


#
# Functions for decomposing `lf.query` into pre-llm and post-llm operations.
#


def query_prompt(
prompt: Union[str, lf.Template, Any],
schema: Union[
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
] = None,
schema: schema_lib.SchemaType | None = None,
**kwargs,
) -> lf.Message:
"""Returns the final prompt sent to LLM for `lf.query`."""
@@ -295,9 +470,7 @@ def query_prompt(

def query_output(
response: Union[str, lf.Message],
schema: Union[
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
],
schema: schema_lib.SchemaType | None = None,
**kwargs,
) -> Any:
"""Returns the final output of `lf.query` from a provided LLM response."""
@@ -308,6 +481,11 @@
)


#
# Functions for computing reward of an LLM response based on a mapping example.
#


def query_reward(
mapping_example: Union[str, mapping.MappingExample],
response: Union[str, lf.Message],
@@ -362,6 +540,11 @@ def _reward(self, input, expected_output, metadata): # pylint: disable=redefine
return _reward


#
# Functions for tracking `lf.query` invocations.
#


class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
"""A class to represent the invocation of `lf.query`."""

