From e21e0ca5c1f759dc614af87ee04bb0487577e1b0 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Sat, 30 Nov 2024 14:46:24 -0800 Subject: [PATCH] Add `lf.query_with_consistency`. `lf.query_with_consistency` is a drop-in replacement for `lf.query` with compatible arguments and default behavior. It allows users to apply consistency methods to `lf.query` with different models and number of samples. The default voting function is LLM-based majority voting, users could plugin their own voting logic through the `vote_fn` argument. Example usage: sampling Gpt4 and Gemini2 each with 5 samples and obtain the final output through majority voting. ```python lf.query_with_consistency( 'compute 256 * 345', int, lm=[lf.llms.Gpt4(), lf.llms.Gemini2()], num_samples=5, ... ) ``` PiperOrigin-RevId: 701557382 --- langfun/__init__.py | 2 + langfun/core/structured/__init__.py | 2 + langfun/core/structured/consistency.py | 147 ++++++++++++++++++++ langfun/core/structured/consistency_test.py | 78 +++++++++++ langfun/core/structured/querying.py | 12 +- langfun/core/structured/schema.py | 3 + 6 files changed, 235 insertions(+), 9 deletions(-) create mode 100644 langfun/core/structured/consistency.py create mode 100644 langfun/core/structured/consistency_test.py diff --git a/langfun/__init__.py b/langfun/__init__.py index bad173a..db00c72 100644 --- a/langfun/__init__.py +++ b/langfun/__init__.py @@ -37,6 +37,8 @@ track_queries = structured.track_queries +query_with_consistency = structured.query_with_consistency + # Helper functions for input/output transformations based on # `lf.query` (e.g. jax-on-beam could use these for batch processing) query_prompt = structured.query_prompt diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index a32c207..5de77d6 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -60,6 +60,8 @@ from langfun.core.structured.querying import query_output from langfun.core.structured.querying import query_reward +from langfun.core.structured.consistency import query_with_consistency + from langfun.core.structured.description import describe from langfun.core.structured.completion import complete diff --git a/langfun/core/structured/consistency.py b/langfun/core/structured/consistency.py new file mode 100644 index 0000000..f7dd14f --- /dev/null +++ b/langfun/core/structured/consistency.py @@ -0,0 +1,147 @@ +# Copyright 2024 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""`lf.query` with consistency.""" + +from typing import Any, Callable, Type, Union + +import langfun.core as lf +from langfun.core.structured import mapping +from langfun.core.structured import querying +from langfun.core.structured import schema as schema_lib +import pyglove as pg + + +def majority_voting( + outputs: list[Any], + schema: schema_lib.SchemaType | None = None, + lm: lf.LanguageModel | None = None, +) -> Any: + return querying.query( + prompt=( + 'Derive an object from the following objects based on majority ' + 'voting: {{outputs}}.' + ), + schema=schema, + outputs=outputs, + lm=lm, + ) + + +ConsistencyFn = Union[ + # Signature: `fn(outputs) -> output` + Callable[[list[Any]], Any], + # Signature: `fn(outputs, schema) -> output` + Callable[[list[Any], schema_lib.SchemaType | None], Any], + # Signature: `fn(outputs, schema, lm) -> output` + Callable[ + [list[Any], schema_lib.SchemaType | None, lf.LanguageModel | None], Any + ] +] + + +def query_with_consistency( + prompt: Union[str, lf.Template, Any], + schema: Union[ + schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None + ] = None, + default: Any = lf.RAISE_IF_HAS_ERROR, + *, + vote_fn: ConsistencyFn = majority_voting, + lm: lf.LanguageModel | list[lf.LanguageModel] | None = None, + consistency_lm: lf.LanguageModel | None = None, + examples: list[mapping.MappingExample] | None = None, + cache_seed: int | None = 0, + num_samples: int = 5, + **kwargs, +) -> Any: + """`lf.query` with consistency. + + This function is a wrapper around `lf.query` to apply consistency among the + return values of multiple calls to `lf.query`. It takes a list of language + models as input, and returns the final object by applying a voting + function. The voting function takes a list of outputs and a schema as + arguments, and returns the final object with consistency applied. + + Args: + prompt: A str (may contain {{}} as template) as natural language input, or a + `pg.Symbolic` object as structured input as prompt to LLM. + schema: A type annotation as the schema for output object. If str (default), + the response will be a str in natural language. + default: The default value if parsing failed. If not specified, error will + be raised. + vote_fn: A function to vote for the final output among the return values of + individual calls to `lf.query`. It takes a list of outputs and a schema as + arguments, returns the final object with consistency applied. + lm: The language model to use. If not specified, the language model from + `lf.context` context manager will be used. + consistency_lm: The language model to use for consistency. If None, `lm` + will be used. + examples: An optional list of fewshot examples for helping parsing. If None, + the default one-shot example will be added. + cache_seed: The seed for the cache. + num_samples: The number of samples to obtain from each language model being + requested. + **kwargs: Additional arguments to pass to `lf.query`. + + Returns: + The final object for the requested schema, with consistency applied. + """ + def query(inputs): + lm, example_i = inputs + return querying.query( + prompt, schema, lm=lm, examples=examples, + # Usually num_examples should not be large, so we multiple the user + # provided cache seed by 100 to avoid collision. + cache_seed=None if cache_seed is None else cache_seed * 100 + example_i, + **kwargs, + ) + + if not isinstance(lm, list): + lm_list = [lm] + else: + lm_list = lm + + query_inputs = [] + for lm in lm_list: + query_inputs.extend([(lm, i) for i in range(num_samples)]) + + # Concurrently sample the outputs from the language models. + samples = [] + last_error = None + for _, output, error in lf.concurrent_map( + query, query_inputs, max_workers=max(64, len(lm_list) * num_samples), + silence_on_errors=mapping.MappingError + ): + if error is None: + samples.append(output) + else: + last_error = error + + if not samples: + if default is not lf.RAISE_IF_HAS_ERROR: + return default + raise ValueError( + f'No valid output from {num_samples} samples. Last error: {last_error}' + ) + if len(samples) == 1: + return samples[0] + + # Apply the consistency function. + if consistency_lm is None and len(lm_list) == 1: + consistency_lm = lm_list[0] + + vote_fn = pg.typing.callable_ext.CallableWithOptionalKeywordArgs( + vote_fn, ['schema', 'lm'] + ) + return vote_fn(samples, schema=schema, lm=consistency_lm) diff --git a/langfun/core/structured/consistency_test.py b/langfun/core/structured/consistency_test.py new file mode 100644 index 0000000..a2ec21e --- /dev/null +++ b/langfun/core/structured/consistency_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for querying with consistency.""" + +import unittest + +from langfun.core.llms import fake +from langfun.core.structured import consistency + + +class ConsistencyTest(unittest.TestCase): + + def test_basic(self): + self.assertEqual( + consistency.query_with_consistency( + 'Compute 1 + 2', + int, + lm=[ + fake.StaticResponse('1'), + fake.StaticResponse('2'), + ], + consistency_lm=fake.StaticResponse('3'), + num_samples=2, + ), + 3 + ) + + def test_default_value(self): + self.assertIsNone( + consistency.query_with_consistency( + 'Compute 1 + 2', + int, + default=None, + lm=[ + fake.StaticResponse('ab'), + fake.StaticResponse('cd'), + ], + num_samples=2, + ), + ) + + def test_no_valid_output(self): + with self.assertRaisesRegex(ValueError, 'No valid output from .*'): + consistency.query_with_consistency( + 'Compute 1 + 2', + int, + lm=[ + fake.StaticResponse('ab'), + fake.StaticResponse('cd'), + ], + num_samples=2, + ) + + def test_single_output(self): + self.assertEqual( + consistency.query_with_consistency( + 'Compute 1 + 2', + int, + lm=[fake.StaticResponse('3'), fake.StaticResponse('abc')], + num_samples=1, + ), + 3 + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/langfun/core/structured/querying.py b/langfun/core/structured/querying.py index b85a632..eee872f 100644 --- a/langfun/core/structured/querying.py +++ b/langfun/core/structured/querying.py @@ -105,9 +105,7 @@ def _query_structure_cls( def query( prompt: Union[str, lf.Template, Any], - schema: Union[ - schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None - ] = None, + schema: schema_lib.SchemaType | None = None, default: Any = lf.RAISE_IF_HAS_ERROR, *, lm: lf.LanguageModel | None = None, @@ -282,9 +280,7 @@ def _result(message: lf.Message): def query_prompt( prompt: Union[str, lf.Template, Any], - schema: Union[ - schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None - ] = None, + schema: schema_lib.SchemaType | None = None, **kwargs, ) -> lf.Message: """Returns the final prompt sent to LLM for `lf.query`.""" @@ -295,9 +291,7 @@ def query_prompt( def query_output( response: Union[str, lf.Message], - schema: Union[ - schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None - ], + schema: schema_lib.SchemaType | None = None, **kwargs, ) -> Any: """Returns the final output of `lf.query` from a provided LLM response.""" diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index c1677d9..6c7bc3c 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -227,6 +227,9 @@ def _html_tree_view_tooltip( ) +SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]] + + def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]: """Returns a list of top level value specs from a symbolic value.""" top_level_object_specs = []