Add lf.query_with_consistency.
`lf.query_with_consistency` is a drop-in replacement for `lf.query`, with compatible arguments and default behavior. It allows users to apply consistency methods to `lf.query` across different models and numbers of samples. The default voting function is LLM-based majority voting; users can plug in their own voting logic through the `vote_fn` argument.

Example usage: sampling Gpt4 and Gemini2 with 5 samples each and obtaining the final output through majority voting.

```python
lf.query_with_consistency(
  'compute 256 * 345',
  int,
  lm=[lf.llms.Gpt4(), lf.llms.Gemini2()],
  num_samples=5,
  ...
)
```
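
For example, a custom voting function can be as simple as exact-match majority voting over the sampled outputs; `vote_fn` may declare only the `outputs` parameter, since `schema` and `lm` are optional keywords. A minimal sketch, assuming hashable outputs such as ints (`exact_match_vote` is a hypothetical helper, not part of the library):

```python
import collections
from typing import Any

import langfun as lf


def exact_match_vote(outputs: list[Any]) -> Any:
  """Picks the most frequent sampled output without calling an LLM."""
  return collections.Counter(outputs).most_common(1)[0][0]


result = lf.query_with_consistency(
    'compute 256 * 345',
    int,
    vote_fn=exact_match_vote,
    lm=[lf.llms.Gpt4(), lf.llms.Gemini2()],
    num_samples=5,
)
```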

PiperOrigin-RevId: 701557382
daiyip authored and langfun authors committed Dec 17, 2024
1 parent 371a8ef commit 95cb9b9
Showing 6 changed files with 235 additions and 9 deletions.
2 changes: 2 additions & 0 deletions langfun/__init__.py
@@ -37,6 +37,8 @@

track_queries = structured.track_queries

query_with_consistency = structured.query_with_consistency

# Helper functions for input/output transformations based on
# `lf.query` (e.g. jax-on-beam could use these for batch processing)
query_prompt = structured.query_prompt
2 changes: 2 additions & 0 deletions langfun/core/structured/__init__.py
@@ -60,6 +60,8 @@
from langfun.core.structured.querying import query_output
from langfun.core.structured.querying import query_reward

from langfun.core.structured.consistency import query_with_consistency

from langfun.core.structured.description import describe
from langfun.core.structured.completion import complete

147 changes: 147 additions & 0 deletions langfun/core/structured/consistency.py
@@ -0,0 +1,147 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""`lf.query` with consistency."""

from typing import Any, Callable, Type, Union

import langfun.core as lf
from langfun.core.structured import mapping
from langfun.core.structured import querying
from langfun.core.structured import schema as schema_lib
import pyglove as pg


def majority_voting(
    outputs: list[Any],
    schema: schema_lib.SchemaType | None = None,
    lm: lf.LanguageModel | None = None,
) -> Any:
  return querying.query(
      prompt=(
          'Derive an object from the following objects based on majority '
          'voting: {{outputs}}.'
      ),
      schema=schema,
      outputs=outputs,
      lm=lm,
  )


ConsistencyFn = Union[
    # Signature: `fn(outputs) -> output`
    Callable[[list[Any]], Any],
    # Signature: `fn(outputs, schema) -> output`
    Callable[[list[Any], schema_lib.SchemaType | None], Any],
    # Signature: `fn(outputs, schema, lm) -> output`
    Callable[
        [list[Any], schema_lib.SchemaType | None, lf.LanguageModel | None], Any
    ]
]


def query_with_consistency(
    prompt: Union[str, lf.Template, Any],
    schema: Union[
        schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
    ] = None,
    default: Any = lf.RAISE_IF_HAS_ERROR,
    *,
    vote_fn: ConsistencyFn = majority_voting,
    lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
    consistency_lm: lf.LanguageModel | None = None,
    examples: list[mapping.MappingExample] | None = None,
    cache_seed: int | None = 0,
    num_samples: int = 5,
    **kwargs,
) -> Any:
  """`lf.query` with consistency.

  This function is a wrapper around `lf.query` that applies consistency across
  the return values of multiple calls to `lf.query`. It takes one or more
  language models as input and returns the final object by applying a voting
  function. The voting function takes a list of outputs and a schema as
  arguments, and returns the final object with consistency applied.

  Args:
    prompt: A str (may contain {{}} as template) as natural language input, or
      a `pg.Symbolic` object as structured input as prompt to LLM.
    schema: A type annotation as the schema for the output object. If str
      (default), the response will be a str in natural language.
    default: The default value if parsing failed. If not specified, an error
      will be raised.
    vote_fn: A function to vote for the final output among the return values of
      individual calls to `lf.query`. It takes a list of outputs and a schema
      as arguments, and returns the final object with consistency applied.
    lm: The language model(s) to use. If not specified, the language model from
      the `lf.context` context manager will be used.
    consistency_lm: The language model to use for consistency. If None, `lm`
      will be used.
    examples: An optional list of fewshot examples for helping parsing. If
      None, the default one-shot example will be added.
    cache_seed: The seed for the cache.
    num_samples: The number of samples to obtain from each language model being
      requested.
    **kwargs: Additional arguments to pass to `lf.query`.

  Returns:
    The final object for the requested schema, with consistency applied.
  """
  def query(inputs):
    lm, example_i = inputs
    return querying.query(
        prompt, schema, lm=lm, examples=examples,
        # `num_samples` is usually small, so we multiply the user-provided
        # cache seed by 100 to avoid collisions between samples.
        cache_seed=None if cache_seed is None else cache_seed * 100 + example_i,
        **kwargs,
    )

  if not isinstance(lm, list):
    lm_list = [lm]
  else:
    lm_list = lm

  query_inputs = []
  for lm in lm_list:
    query_inputs.extend([(lm, i) for i in range(num_samples)])

  # Concurrently sample the outputs from the language models.
  samples = []
  last_error = None
  for _, output, error in lf.concurrent_map(
      query, query_inputs, max_workers=max(64, len(lm_list) * num_samples),
      silence_on_errors=mapping.MappingError
  ):
    if error is None:
      samples.append(output)
    else:
      last_error = error

  if not samples:
    if default is not lf.RAISE_IF_HAS_ERROR:
      return default
    raise ValueError(
        f'No valid output from {num_samples} samples. Last error: {last_error}'
    )
  if len(samples) == 1:
    return samples[0]

  # Apply the consistency function.
  if consistency_lm is None and len(lm_list) == 1:
    consistency_lm = lm_list[0]

  vote_fn = pg.typing.callable_ext.CallableWithOptionalKeywordArgs(
      vote_fn, ['schema', 'lm']
  )
  return vote_fn(samples, schema=schema, lm=consistency_lm)
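
Because `vote_fn` is wrapped in `pg.typing.callable_ext.CallableWithOptionalKeywordArgs` with `['schema', 'lm']`, a voting function may declare any subset of those keywords; keywords it does not declare are simply not passed. A minimal sketch of a deterministic, schema-aware alternative to the default LLM-based voting (`median_vote` is a hypothetical helper, only meaningful for numeric outputs):

```python
import statistics
from typing import Any

from langfun.core.structured import schema as schema_lib


def median_vote(
    outputs: list[Any],
    schema: schema_lib.SchemaType | None = None,
) -> Any:
  """Returns the median of numeric samples; no LLM call is made."""
  value = statistics.median(outputs)
  # `schema` is the raw value the caller passed to `query_with_consistency`,
  # so `int` means an integer result was requested.
  return int(value) if schema is int else value
```

Passing `vote_fn=median_vote` works even though the function does not accept `lm`; the wrapper omits that keyword at call time.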
78 changes: 78 additions & 0 deletions langfun/core/structured/consistency_test.py
@@ -0,0 +1,78 @@
# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for querying with consistency."""

import unittest

from langfun.core.llms import fake
from langfun.core.structured import consistency


class ConsistencyTest(unittest.TestCase):

  def test_basic(self):
    self.assertEqual(
        consistency.query_with_consistency(
            'Compute 1 + 2',
            int,
            lm=[
                fake.StaticResponse('1'),
                fake.StaticResponse('2'),
            ],
            consistency_lm=fake.StaticResponse('3'),
            num_samples=2,
        ),
        3
    )

  def test_default_value(self):
    self.assertIsNone(
        consistency.query_with_consistency(
            'Compute 1 + 2',
            int,
            default=None,
            lm=[
                fake.StaticResponse('ab'),
                fake.StaticResponse('cd'),
            ],
            num_samples=2,
        ),
    )

  def test_no_valid_output(self):
    with self.assertRaisesRegex(ValueError, 'No valid output from .*'):
      consistency.query_with_consistency(
          'Compute 1 + 2',
          int,
          lm=[
              fake.StaticResponse('ab'),
              fake.StaticResponse('cd'),
          ],
          num_samples=2,
      )

  def test_single_output(self):
    self.assertEqual(
        consistency.query_with_consistency(
            'Compute 1 + 2',
            int,
            lm=[fake.StaticResponse('3'), fake.StaticResponse('abc')],
            num_samples=1,
        ),
        3
    )


if __name__ == '__main__':
  unittest.main()
12 changes: 3 additions & 9 deletions langfun/core/structured/querying.py
@@ -105,9 +105,7 @@ def _query_structure_cls(

def query(
    prompt: Union[str, lf.Template, Any],
    schema: Union[
        schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
    ] = None,
    schema: schema_lib.SchemaType | None = None,
    default: Any = lf.RAISE_IF_HAS_ERROR,
    *,
    lm: lf.LanguageModel | None = None,
@@ -282,9 +280,7 @@ def _result(message: lf.Message):

def query_prompt(
    prompt: Union[str, lf.Template, Any],
    schema: Union[
        schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
    ] = None,
    schema: schema_lib.SchemaType | None = None,
    **kwargs,
) -> lf.Message:
  """Returns the final prompt sent to LLM for `lf.query`."""
@@ -295,9 +291,7 @@ def query_prompt(

def query_output(
    response: Union[str, lf.Message],
    schema: Union[
        schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
    ],
    schema: schema_lib.SchemaType | None = None,
    **kwargs,
) -> Any:
  """Returns the final output of `lf.query` from a provided LLM response."""
3 changes: 3 additions & 0 deletions langfun/core/structured/schema.py
@@ -227,6 +227,9 @@ def _html_tree_view_tooltip(
)


SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]


def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
  """Returns a list of top level value specs from a symbolic value."""
  top_level_object_specs = []
