From be3102885596d6d47f81199af8942ad0889a7a1a Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Fri, 13 Dec 2024 13:01:40 -0800 Subject: [PATCH] Clean up lf.structured. 1) Rename prompting.py to query.py as it's easier for users to locate the source code of `lf.query`. 2) Make helper classes without backward compatibility guarantee private. PiperOrigin-RevId: 705976528 --- docs/notebooks/langfun101.ipynb | 2 +- langfun/core/coding/python/correction.py | 4 +- langfun/core/structured/__init__.py | 30 ++--- langfun/core/structured/completion.py | 4 +- langfun/core/structured/completion_test.py | 8 +- langfun/core/structured/description.py | 4 +- langfun/core/structured/description_test.py | 6 +- .../core/structured/function_generation.py | 6 +- langfun/core/structured/parsing.py | 16 +-- langfun/core/structured/parsing_test.py | 16 +-- .../structured/{prompting.py => query.py} | 16 +-- .../{prompting_test.py => query_test.py} | 102 +++++++------- langfun/core/structured/schema.py | 127 +++++++++--------- langfun/core/structured/schema_test.py | 44 +++--- langfun/core/structured/scoring.py | 6 +- langfun/core/structured/tokenization.py | 4 +- 16 files changed, 191 insertions(+), 204 deletions(-) rename langfun/core/structured/{prompting.py => query.py} (97%) rename langfun/core/structured/{prompting_test.py => query_test.py} (91%) diff --git a/docs/notebooks/langfun101.ipynb b/docs/notebooks/langfun101.ipynb index 47e487cb..85b322ec 100644 --- a/docs/notebooks/langfun101.ipynb +++ b/docs/notebooks/langfun101.ipynb @@ -10,7 +10,7 @@ "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/langfun/blob/main/docs/notebooks/langfun101.ipynb)\n", "\n", - "Effective programming of Large Language Models (LLMs) demands a seamless integration of natural language text with structured data. 
Langfun, leveraging [PyGlove](https://github.com/google/pyglove)'s symbolic objects, provides a simple yet powerful interface for mapping between Python objects with the assistance of LLMs. The input/output objects may include natural language texts (in string form), structured data (objects of a specific class), modalities (such as images), and more. The unified API for accomplishing all conceivable mappings is [`lf.query`](https://github.com/google/langfun/blob/145f4a48603d2e6e19f9025e032ac3c86dfd3e35/langfun/core/structured/prompting.py#L103).\n" + "Effective programming of Large Language Models (LLMs) demands a seamless integration of natural language text with structured data. Langfun, leveraging [PyGlove](https://github.com/google/pyglove)'s symbolic objects, provides a simple yet powerful interface for mapping between Python objects with the assistance of LLMs. The input/output objects may include natural language texts (in string form), structured data (objects of a specific class), modalities (such as images), and more. The unified API for accomplishing all conceivable mappings is [`lf.query`](https://github.com/google/langfun/blob/145f4a48603d2e6e19f9025e032ac3c86dfd3e35/langfun/core/structured/query.py#L103).\n" ] }, { diff --git a/langfun/core/coding/python/correction.py b/langfun/core/coding/python/correction.py index d2c1d5dc..bcac5647 100644 --- a/langfun/core/coding/python/correction.py +++ b/langfun/core/coding/python/correction.py @@ -76,7 +76,7 @@ def run_with_correction( # Delay import at runtime to avoid circular depenency. # pylint: disable=g-import-not-at-top # pytype: disable=import-error - from langfun.core.structured import prompting + from langfun.core.structured import query as query_lib # pytype: enable=import-error # pylint: enable=g-import-not-at-top @@ -119,7 +119,7 @@ def result_and_error(code: str) -> tuple[Any, str | None]: # structure. try: # Disable autofix for code correction to avoid recursion. 
- correction = prompting.query( + correction = query_lib.query( CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0 ) except errors.CodeError: diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index 685b067e..4a8e118f 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -36,12 +36,6 @@ from langfun.core.structured.schema import annotation from langfun.core.structured.schema import structure_from_python -from langfun.core.structured.schema import SchemaRepr -from langfun.core.structured.schema import SchemaJsonRepr -from langfun.core.structured.schema import SchemaPythonRepr -from langfun.core.structured.schema import ValueRepr -from langfun.core.structured.schema import ValueJsonRepr -from langfun.core.structured.schema import ValuePythonRepr from langfun.core.structured.schema import schema_repr from langfun.core.structured.schema import source_form from langfun.core.structured.schema import value_repr @@ -56,26 +50,18 @@ from langfun.core.structured.mapping import MappingError from langfun.core.structured.mapping import MappingExample -from langfun.core.structured.parsing import ParseStructure -from langfun.core.structured.parsing import ParseStructureJson -from langfun.core.structured.parsing import ParseStructurePython from langfun.core.structured.parsing import parse from langfun.core.structured.parsing import call -from langfun.core.structured.prompting import QueryStructure -from langfun.core.structured.prompting import QueryStructureJson -from langfun.core.structured.prompting import QueryStructurePython -from langfun.core.structured.prompting import query -from langfun.core.structured.prompting import query_prompt -from langfun.core.structured.prompting import query_output -from langfun.core.structured.prompting import query_reward -from langfun.core.structured.prompting import QueryInvocation -from langfun.core.structured.prompting import track_queries - -from 
langfun.core.structured.description import DescribeStructure -from langfun.core.structured.description import describe +from langfun.core.structured.query import query +from langfun.core.structured.query import query_prompt +from langfun.core.structured.query import query_output +from langfun.core.structured.query import query_reward +from langfun.core.structured.query import QueryInvocation +from langfun.core.structured.query import track_queries +from langfun.core.structured import query as query_lib -from langfun.core.structured.completion import CompleteStructure +from langfun.core.structured.description import describe from langfun.core.structured.completion import complete from langfun.core.structured.scoring import score diff --git a/langfun/core/structured/completion.py b/langfun/core/structured/completion.py index 741d767e..dda8f150 100644 --- a/langfun/core/structured/completion.py +++ b/langfun/core/structured/completion.py @@ -21,7 +21,7 @@ import pyglove as pg -class CompleteStructure(mapping.Mapping): +class _CompleteStructure(mapping.Mapping): """Complete structure by filling the missing fields.""" input: Annotated[ @@ -241,7 +241,7 @@ class Flight(pg.Object): Returns: The result based on the schema. 
""" - t = CompleteStructure( + t = _CompleteStructure( input=schema_lib.mark_missing(input_value), default=default, examples=examples, diff --git a/langfun/core/structured/completion_test.py b/langfun/core/structured/completion_test.py index 26a44f42..a93e6447 100644 --- a/langfun/core/structured/completion_test.py +++ b/langfun/core/structured/completion_test.py @@ -46,7 +46,7 @@ class TripPlan(pg.Object): class CompleteStructureTest(unittest.TestCase): def test_render_no_examples(self): - l = completion.CompleteStructure() + l = completion._CompleteStructure() input_value = schema_lib.mark_missing( TripPlan.partial( place='San Francisco', @@ -120,7 +120,7 @@ class Activity: ) def test_render_no_class_definitions(self): - l = completion.CompleteStructure() + l = completion._CompleteStructure() input_value = schema_lib.mark_missing( TripPlan.partial( place='San Francisco', @@ -200,7 +200,7 @@ def test_render_no_class_definitions(self): ) def test_render_with_examples(self): - l = completion.CompleteStructure() + l = completion._CompleteStructure() input_value = schema_lib.mark_missing( TripPlan.partial( place='San Francisco', @@ -411,7 +411,7 @@ class Animal(pg.Object): modalities.Image.from_bytes(b'image_of_elephant'), ) ) - l = completion.CompleteStructure( + l = completion._CompleteStructure( input=input_value, examples=[ mapping.MappingExample( diff --git a/langfun/core/structured/description.py b/langfun/core/structured/description.py index 536fdc8a..5d8d2978 100644 --- a/langfun/core/structured/description.py +++ b/langfun/core/structured/description.py @@ -22,7 +22,7 @@ @pg.use_init_args(['examples']) -class DescribeStructure(mapping.Mapping): +class _DescribeStructure(mapping.Mapping): """Describe a structured value in natural language.""" input_title = 'PYTHON_OBJECT' @@ -106,7 +106,7 @@ class Flight(pg.Object): Returns: The parsed result based on the schema. 
""" - return DescribeStructure( + return _DescribeStructure( input=value, context=context, examples=examples or default_describe_examples(), diff --git a/langfun/core/structured/description_test.py b/langfun/core/structured/description_test.py index dba8dd2e..4c74f768 100644 --- a/langfun/core/structured/description_test.py +++ b/langfun/core/structured/description_test.py @@ -36,7 +36,7 @@ class Itinerary(pg.Object): class DescribeStructureTest(unittest.TestCase): def test_render(self): - l = description_lib.DescribeStructure( + l = description_lib._DescribeStructure( input=Itinerary( day=1, type='daytime', @@ -137,7 +137,7 @@ def test_render_no_examples(self): ], hotel=None, ) - l = description_lib.DescribeStructure( + l = description_lib._DescribeStructure( input=value, context='1 day itinerary to SF' ) self.assertEqual( @@ -187,7 +187,7 @@ def test_render_no_context(self): ], hotel=None, ) - l = description_lib.DescribeStructure(input=value) + l = description_lib._DescribeStructure(input=value) self.assertEqual( l.render().text, inspect.cleandoc(""" diff --git a/langfun/core/structured/function_generation.py b/langfun/core/structured/function_generation.py index ef543808..40db7d34 100644 --- a/langfun/core/structured/function_generation.py +++ b/langfun/core/structured/function_generation.py @@ -21,7 +21,7 @@ from langfun.core import language_model from langfun.core import template from langfun.core.coding import python -from langfun.core.structured import prompting +from langfun.core.structured import query as query_lib import pyglove as pg @@ -39,7 +39,7 @@ class PythonFunctionSignature(pg.Object): unittest_examples = None for _ in range(num_retries): - r = prompting.query( + r = query_lib.query( PythonFunctionSignature(signature=signature), list[UnitTest], lm=lm, @@ -145,7 +145,7 @@ def calculate_area_circle(radius: float) -> float: last_error = None for _ in range(num_retries): try: - source_code = prompting.query( + source_code = query_lib.query( 
PythonFunctionPrompt(signature=signature), lm=lm ) f = python.evaluate(source_code, global_vars=context) diff --git a/langfun/core/structured/parsing.py b/langfun/core/structured/parsing.py index 1b5689a2..34d3312b 100644 --- a/langfun/core/structured/parsing.py +++ b/langfun/core/structured/parsing.py @@ -16,13 +16,13 @@ import langfun.core as lf from langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import query as query_lib from langfun.core.structured import schema as schema_lib import pyglove as pg @lf.use_init_args(['schema', 'default', 'examples']) -class ParseStructure(mapping.Mapping): +class _ParseStructure(mapping.Mapping): """Parse an object out from a natural language text.""" context_title = 'USER_REQUEST' @@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping): ] -class ParseStructureJson(ParseStructure): +class _ParseStructureJson(_ParseStructure): """Parse an object out from a NL text using JSON as the protocol.""" preamble = """ @@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure): output_title = 'JSON' -class ParseStructurePython(ParseStructure): +class _ParseStructurePython(_ParseStructure): """Parse an object out from a NL text using Python as the protocol.""" preamble = """ @@ -271,7 +271,7 @@ def call( return lm_output if returns_message else lm_output.text # Call `parsing_lm` for structured parsing. 
- parsing_message = prompting.query( + parsing_message = query_lib.query( lm_output.text, schema, examples=parsing_examples, @@ -293,11 +293,11 @@ def call( def _parse_structure_cls( protocol: schema_lib.SchemaProtocol, -) -> Type[ParseStructure]: +) -> Type[_ParseStructure]: if protocol == 'json': - return ParseStructureJson + return _ParseStructureJson elif protocol == 'python': - return ParseStructurePython + return _ParseStructurePython else: raise ValueError(f'Unknown protocol: {protocol!r}.') diff --git a/langfun/core/structured/parsing_test.py b/langfun/core/structured/parsing_test.py index 94667023..802d0bd1 100644 --- a/langfun/core/structured/parsing_test.py +++ b/langfun/core/structured/parsing_test.py @@ -37,7 +37,7 @@ class Itinerary(pg.Object): class ParseStructurePythonTest(unittest.TestCase): def test_render_no_examples(self): - l = parsing.ParseStructurePython(int) + l = parsing._ParseStructurePython(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( l.render(input=m, context='Compute 12 / 6 + 2.').text, @@ -62,7 +62,7 @@ def test_render_no_examples(self): ) def test_render_no_context(self): - l = parsing.ParseStructurePython(int) + l = parsing._ParseStructurePython(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( @@ -85,7 +85,7 @@ def test_render_no_context(self): ) def test_render(self): - l = parsing.ParseStructurePython( + l = parsing._ParseStructurePython( int, examples=[ mapping.MappingExample( @@ -212,7 +212,7 @@ def test_invocation(self): ), override_attrs=True, ): - l = parsing.ParseStructurePython( + l = parsing._ParseStructurePython( [Itinerary], examples=[ mapping.MappingExample( @@ -295,7 +295,7 @@ def test_parse(self): class ParseStructureJsonTest(unittest.TestCase): def test_render_no_examples(self): - l = parsing.ParseStructureJson(int) + l = parsing._ParseStructureJson(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( l.render(input=m, context='Compute 12 / 6 + 2.').text, @@ -320,7 +320,7 @@ def 
test_render_no_examples(self): ) def test_render_no_context(self): - l = parsing.ParseStructureJson(int) + l = parsing._ParseStructureJson(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( @@ -343,7 +343,7 @@ def test_render_no_context(self): ) def test_render(self): - l = parsing.ParseStructureJson( + l = parsing._ParseStructureJson( int, examples=[ mapping.MappingExample( @@ -504,7 +504,7 @@ def test_invocation(self): override_attrs=True, ): message = lf.LangFunc(lm_input)() - l = parsing.ParseStructureJson( + l = parsing._ParseStructureJson( [Itinerary], examples=[ mapping.MappingExample( diff --git a/langfun/core/structured/prompting.py b/langfun/core/structured/query.py similarity index 97% rename from langfun/core/structured/prompting.py rename to langfun/core/structured/query.py index 50059923..9dc7ee32 100644 --- a/langfun/core/structured/prompting.py +++ b/langfun/core/structured/query.py @@ -26,7 +26,7 @@ @lf.use_init_args(['schema', 'default', 'examples']) -class QueryStructure(mapping.Mapping): +class _QueryStructure(mapping.Mapping): """Query an object out from a natural language text.""" context_title = 'CONTEXT' @@ -38,7 +38,7 @@ class QueryStructure(mapping.Mapping): ] -class QueryStructureJson(QueryStructure): +class _QueryStructureJson(_QueryStructure): """Query a structured value using JSON as the protocol.""" preamble = """ @@ -52,10 +52,10 @@ class QueryStructureJson(QueryStructure): 1 + 1 = {{ schema_title }}: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}} {{ output_title}}: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}} """ protocol = 'json' @@ -63,7 +63,7 @@ class QueryStructureJson(QueryStructure): output_title = 'JSON' -class QueryStructurePython(QueryStructure): +class 
_QueryStructurePython(_QueryStructure): """Query a structured value using Python as the protocol.""" preamble = """ @@ -94,11 +94,11 @@ class Answer: def _query_structure_cls( protocol: schema_lib.SchemaProtocol, -) -> Type[QueryStructure]: +) -> Type[_QueryStructure]: if protocol == 'json': - return QueryStructureJson + return _QueryStructureJson elif protocol == 'python': - return QueryStructurePython + return _QueryStructurePython else: raise ValueError(f'Unknown protocol: {protocol!r}.') diff --git a/langfun/core/structured/prompting_test.py b/langfun/core/structured/query_test.py similarity index 91% rename from langfun/core/structured/prompting_test.py rename to langfun/core/structured/query_test.py index 2e54b1df..d075c4cc 100644 --- a/langfun/core/structured/prompting_test.py +++ b/langfun/core/structured/query_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for structured prompting.""" +"""Tests for structured query_lib.""" import inspect import math @@ -23,7 +23,7 @@ from langfun.core.llms import fake from langfun.core.llms.cache import in_memory from langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import query as query_lib import pyglove as pg @@ -51,7 +51,7 @@ def assert_render( expected_modalities: int = 0, **kwargs, ): - m = prompting.query( + m = query_lib.query( prompt, schema=schema, examples=examples, **kwargs, returns_message=True ) @@ -67,14 +67,14 @@ def assert_render( def test_call(self): lm = fake.StaticSequence(['1']) - self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm), 1) + self.assertEqual(query_lib.query('what is 1 + 0', int, lm=lm), 1) # Testing calling the same `lm` without copy. 
with self.assertRaises(IndexError): - prompting.query('what is 1 + 2', int, lm=lm) + query_lib.query('what is 1 + 2', int, lm=lm) self.assertEqual( - prompting.query( + query_lib.query( 'what is 1 + 0', int, lm=lm.clone(), returns_message=True ), lf.AIMessage( @@ -88,17 +88,17 @@ def test_call(self): ), ) self.assertEqual( - prompting.query( + query_lib.query( lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone() ), 1, ) self.assertEqual( - prompting.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()), + query_lib.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()), 1, ) self.assertEqual( - prompting.query( + query_lib.query( 'what is {{x}} + {{y}}', x=1, y=0, @@ -107,7 +107,7 @@ def test_call(self): 'The answer is one.', ) self.assertEqual( - prompting.query( + query_lib.query( Activity.partial(), lm=fake.StaticResponse('Activity(description="hello")'), ), @@ -329,11 +329,11 @@ def test_structure_with_modality_and_examples_to_structure_render(self): def test_bad_protocol(self): with self.assertRaisesRegex(ValueError, 'Unknown protocol'): - prompting.query('what is 1 + 1', int, protocol='text') + query_lib.query('what is 1 + 1', int, protocol='text') def test_query_prompt(self): self.assertEqual( - prompting.query_prompt('what is this?', int), + query_lib.query_prompt('what is this?', int), inspect.cleandoc(""" Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE. 
@@ -368,14 +368,14 @@ class Answer: def test_query_prompt_with_metadata(self): self.assertIn( 'x', - prompting.query_prompt( + query_lib.query_prompt( 'what is this?', metadata_x=1 ).metadata ) self.assertIn( 'x', - prompting.query_prompt( + query_lib.query_prompt( 'what is this?', int, metadata_x=1 @@ -383,7 +383,7 @@ def test_query_prompt_with_metadata(self): ) def test_query_prompt_with_unrooted_template(self): - output = prompting.query_prompt( + output = query_lib.query_prompt( pg.Dict( input=lf.Template( 'what is {{image}}', @@ -395,7 +395,7 @@ def test_query_prompt_with_unrooted_template(self): def test_query_output(self): self.assertEqual( - prompting.query_output( + query_lib.query_output( lf.AIMessage('1'), int, ), @@ -414,7 +414,7 @@ def __reward__(self, inputs: lf.Template) -> None: # Case 1: Reward function based on input and output. self.assertEqual( - prompting.query_reward( + query_lib.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), schema=Answer, @@ -425,7 +425,7 @@ def __reward__(self, inputs: lf.Template) -> None: 1.0 ) self.assertEqual( - prompting.query_reward( + query_lib.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=2, y=3), output=Answer(final_answer=2), @@ -445,7 +445,7 @@ def __reward__(self, inputs: lf.Template, expected_output: 'Answer2'): ) self.assertEqual( - prompting.query_reward( + query_lib.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer2(final_answer=2), @@ -470,7 +470,7 @@ def __reward__(self, ) * metadata['weight'] self.assertEqual( - prompting.query_reward( + query_lib.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer3(final_answer=2), @@ -486,7 +486,7 @@ class Answer4(pg.Object): final_answer: int self.assertIsNone( - prompting.query_reward( + query_lib.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer4(final_answer=2), 
@@ -497,7 +497,7 @@ class Answer4(pg.Object): # Case 5: Not a structured output. self.assertIsNone( - prompting.query_reward( + query_lib.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output='2', @@ -516,7 +516,7 @@ def __reward__(self): with self.assertRaisesRegex( TypeError, '.*Answer5.__reward__` should have signature' ): - prompting.query_reward( + query_lib.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer5(final_answer=2), @@ -528,7 +528,7 @@ def __reward__(self): class QueryStructurePythonTest(unittest.TestCase): def test_render_no_examples(self): - l = prompting.QueryStructurePython( + l = query_lib._QueryStructurePython( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int ) self.assertEqual( @@ -565,7 +565,7 @@ class Answer: ) def test_render(self): - l = prompting.QueryStructurePython( + l = query_lib._QueryStructurePython( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int, examples=[ @@ -675,7 +675,7 @@ def test_invocation(self): ), override_attrs=True, ): - l = prompting.QueryStructurePython( + l = query_lib._QueryStructurePython( input=lm_input, schema=[Itinerary], examples=[ @@ -712,7 +712,7 @@ def test_bad_response(self): mapping.MappingError, 'name .* is not defined', ): - prompting.query('Compute 1 + 2', int) + query_lib.query('Compute 1 + 2', int) def test_autofix(self): lm = fake.StaticSequence([ @@ -723,7 +723,7 @@ def test_autofix(self): ) """), ]) - self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1) + self.assertEqual(query_lib.query('what is 1 + 0', int, lm=lm, autofix=3), 1) def test_response_postprocess(self): with lf.context( @@ -731,12 +731,12 @@ def test_response_postprocess(self): override_attrs=True, ): self.assertEqual( - prompting.query( + query_lib.query( 'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]), '3' ) self.assertEqual( - prompting.query( + query_lib.query( 'Compute 1 + 2', int, 
response_postprocess=lambda x: x.split('\n')[1]), 3 @@ -746,7 +746,7 @@ def test_response_postprocess(self): class QueryStructureJsonTest(unittest.TestCase): def test_render_no_examples(self): - l = prompting.QueryStructureJson( + l = query_lib._QueryStructureJson( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int ) self.assertEqual( @@ -762,10 +762,10 @@ def test_render_no_examples(self): 1 + 1 = SCHEMA: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}} JSON: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}} INPUT_OBJECT: Compute 12 / 6 + 2. @@ -778,7 +778,7 @@ def test_render_no_examples(self): ) def test_render(self): - l = prompting.QueryStructureJson( + l = query_lib._QueryStructureJson( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int, examples=[ @@ -799,10 +799,10 @@ def test_render(self): 1 + 1 = SCHEMA: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}} JSON: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}} INPUT_OBJECT: What is the answer of 1 plus 1? @@ -913,7 +913,7 @@ def test_invocation(self): ), override_attrs=True, ): - l = prompting.QueryStructureJson( + l = query_lib._QueryStructureJson( input=lm_input, schema=[Itinerary], examples=[ @@ -951,14 +951,14 @@ def test_bad_transform(self): mapping.MappingError, 'No JSON dict in the output', ): - prompting.query('Compute 1 + 2', int, protocol='json', cache_seed=1) + query_lib.query('Compute 1 + 2', int, protocol='json', cache_seed=1) # Make sure bad mapping does not impact cache. 
self.assertEqual(len(cache), 0) def test_query(self): lm = fake.StaticSequence(['{"result": 1}']) self.assertEqual( - prompting.query('what is 1 + 0', int, lm=lm, protocol='json'), 1 + query_lib.query('what is 1 + 0', int, lm=lm, protocol='json'), 1 ) @@ -968,8 +968,8 @@ def test_to_html(self): lm = fake.StaticSequence([ 'Activity(description="hi")', ]) - with prompting.track_queries() as queries: - prompting.query('foo', Activity, lm=lm) + with query_lib.track_queries() as queries: + query_lib.query('foo', Activity, lm=lm) self.assertIn('schema', queries[0].to_html_str()) @@ -981,10 +981,10 @@ def test_include_child_scopes(self): 'bar', 'Activity(description="hi")', ]) - with prompting.track_queries() as queries: - prompting.query('foo', lm=lm) - with prompting.track_queries() as child_queries: - prompting.query('give me an activity', Activity, lm=lm) + with query_lib.track_queries() as queries: + query_lib.query('foo', lm=lm) + with query_lib.track_queries() as child_queries: + query_lib.query('give me an activity', Activity, lm=lm) self.assertEqual(len(queries), 2) self.assertTrue(pg.eq(queries[0].input, lf.Template('foo'))) @@ -1008,10 +1008,10 @@ def test_exclude_child_scopes(self): 'bar', 'Activity(description="hi")', ]) - with prompting.track_queries(include_child_scopes=False) as queries: - prompting.query('foo', lm=lm) - with prompting.track_queries(include_child_scopes=False) as child_queries: - prompting.query('give me an activity', Activity, lm=lm) + with query_lib.track_queries(include_child_scopes=False) as queries: + query_lib.query('foo', lm=lm) + with query_lib.track_queries(include_child_scopes=False) as child_queries: + query_lib.query('give me an activity', Activity, lm=lm) self.assertEqual(len(queries), 1) self.assertTrue(pg.eq(queries[0].input, lf.Template('foo'))) @@ -1030,13 +1030,13 @@ def test_exclude_child_scopes(self): def test_concurrent_map(self): def make_query(prompt): - _ = prompting.query(prompt, lm=lm) + _ = 
query_lib.query(prompt, lm=lm) lm = fake.StaticSequence([ 'foo', 'bar', ]) - with prompting.track_queries() as queries: + with query_lib.track_queries() as queries: list(lf.concurrent_map(make_query, ['a', 'b'])) self.assertEqual(len(queries), 2) diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index ad30c232..8705d457 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -339,7 +339,7 @@ def schema_spec(noneable: bool = False) -> pg.typing.ValueSpec: # pylint: disab # -class SchemaRepr(metaclass=abc.ABCMeta): +class _SchemaRepr(metaclass=abc.ABCMeta): """Base class for schema representation.""" @abc.abstractmethod @@ -347,7 +347,7 @@ def repr(self, schema: Schema) -> str: """Returns the representation of the schema.""" -class SchemaPythonRepr(SchemaRepr): +class _SchemaPythonRepr(_SchemaRepr): """Python-representation for a schema.""" def repr( @@ -390,7 +390,7 @@ def result_definition(self, schema: Schema) -> str: def source_form(value, markdown: bool = False) -> str: """Returns the source code form of an object.""" - return ValuePythonRepr().repr(value, markdown=markdown) + return _ValuePythonRepr().repr(value, markdown=markdown) def class_definitions( @@ -600,7 +600,7 @@ def annotation( return x -class SchemaJsonRepr(SchemaRepr): +class _SchemaJsonRepr(_SchemaRepr): """JSON-representation for a schema.""" def repr(self, schema: Schema, **kwargs) -> str: @@ -654,7 +654,7 @@ def _visit(node: Any) -> None: # -class ValueRepr(metaclass=abc.ABCMeta): +class _ValueRepr(metaclass=abc.ABCMeta): """Base class for value representation.""" @abc.abstractmethod @@ -666,7 +666,7 @@ def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any: """Parse a LM generated text into a structured value.""" -class ValuePythonRepr(ValueRepr): +class _ValuePythonRepr(_ValueRepr): """Python-representation for value.""" def repr(self, @@ -681,7 +681,7 @@ def repr(self, if inspect.isclass(value): cls_schema 
= Schema.from_value(value) if isinstance(cls_schema.spec, pg.typing.Object): - object_code = SchemaPythonRepr().class_definitions( + object_code = _SchemaPythonRepr().class_definitions( cls_schema, markdown=markdown, # We add `pg.Object` as additional dependencies to the class @@ -692,7 +692,7 @@ def repr(self, assert object_code is not None return object_code else: - object_code = SchemaPythonRepr().result_definition(cls_schema) + object_code = _SchemaPythonRepr().result_definition(cls_schema) elif isinstance(value, lf.Template): return str(value) else: @@ -778,7 +778,7 @@ def __str__(self) -> str: return r.getvalue() -class ValueJsonRepr(ValueRepr): +class _ValueJsonRepr(_ValueRepr): """JSON-representation for value.""" def repr(self, value: Any, schema: Schema | None = None, **kwargs) -> str: @@ -789,7 +789,7 @@ def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any: """Parse a JSON string into a structured object.""" del schema try: - text = self.cleanup_json(text) + text = cleanup_json(text) v = pg.from_json_str(text, **kwargs) except Exception as e: raise JsonError(text, e) # pylint: disable=raise-missing-from @@ -801,71 +801,72 @@ def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any: )) return v['result'] - def cleanup_json(self, json_str: str) -> str: - """Clean up the LM responded JSON string.""" - # Treatments: - # 1. Extract the JSON string with a top-level dict from the response. - # This prevents the leading and trailing texts in the response to - # be counted as part of the JSON. - # 2. Escape new lines in JSON values. 
- - curly_brackets = 0 - under_json = False - under_str = False - str_begin = -1 - - cleaned = io.StringIO() - for i, c in enumerate(json_str): - if c == '{' and not under_str: - cleaned.write(c) - curly_brackets += 1 - under_json = True - continue - elif not under_json: - continue - if c == '}' and not under_str: - cleaned.write(c) - curly_brackets -= 1 - if curly_brackets == 0: - break - elif c == '"' and json_str[i - 1] != '\\': - under_str = not under_str - if under_str: - str_begin = i - else: - assert str_begin > 0 - str_value = json_str[str_begin : i + 1].replace('\n', '\\n') - cleaned.write(str_value) - str_begin = -1 - elif not under_str: - cleaned.write(c) - - if not under_json: - raise ValueError(f'No JSON dict in the output: {json_str}') - - if curly_brackets > 0: - raise ValueError( - f'Malformated JSON: missing {curly_brackets} closing curly braces.' - ) +def cleanup_json(json_str: str) -> str: + """Clean up the LM responded JSON string.""" + # Treatments: + # 1. Extract the JSON string with a top-level dict from the response. + # This prevents the leading and trailing texts in the response to + # be counted as part of the JSON. + # 2. Escape new lines in JSON values. 
+ + curly_brackets = 0 + under_json = False + under_str = False + str_begin = -1 + + cleaned = io.StringIO() + for i, c in enumerate(json_str): + if c == '{' and not under_str: + cleaned.write(c) + curly_brackets += 1 + under_json = True + continue + elif not under_json: + continue + + if c == '}' and not under_str: + cleaned.write(c) + curly_brackets -= 1 + if curly_brackets == 0: + break + elif c == '"' and json_str[i - 1] != '\\': + under_str = not under_str + if under_str: + str_begin = i + else: + assert str_begin > 0 + str_value = json_str[str_begin : i + 1].replace('\n', '\\n') + cleaned.write(str_value) + str_begin = -1 + elif not under_str: + cleaned.write(c) + + if not under_json: + raise ValueError(f'No JSON dict in the output: {json_str}') + + if curly_brackets > 0: + raise ValueError( + f'Malformated JSON: missing {curly_brackets} closing curly braces.' + ) - return cleaned.getvalue() + return cleaned.getvalue() -def schema_repr(protocol: SchemaProtocol) -> SchemaRepr: +def schema_repr(protocol: SchemaProtocol) -> _SchemaRepr: """Gets a SchemaRepr object from protocol.""" if protocol == 'json': - return SchemaJsonRepr() + return _SchemaJsonRepr() elif protocol == 'python': - return SchemaPythonRepr() + return _SchemaPythonRepr() raise ValueError(f'Unsupported protocol: {protocol}.') -def value_repr(protocol: SchemaProtocol) -> ValueRepr: +def value_repr(protocol: SchemaProtocol) -> _ValueRepr: if protocol == 'json': - return ValueJsonRepr() + return _ValueJsonRepr() elif protocol == 'python': - return ValuePythonRepr() + return _ValuePythonRepr() raise ValueError(f'Unsupported protocol: {protocol}.') diff --git a/langfun/core/structured/schema_test.py b/langfun/core/structured/schema_test.py index 841aefe3..7e173cd9 100644 --- a/langfun/core/structured/schema_test.py +++ b/langfun/core/structured/schema_test.py @@ -572,7 +572,7 @@ def bar_value(self) -> str: schema = schema_lib.Schema([B]) self.assertEqual( - 
schema_lib.SchemaPythonRepr().class_definitions(schema), + schema_lib._SchemaPythonRepr().class_definitions(schema), inspect.cleandoc(''' class Foo: x: int @@ -599,11 +599,11 @@ def foo_value(self) -> int: ) self.assertEqual( - schema_lib.SchemaPythonRepr().result_definition(schema), 'list[B]' + schema_lib._SchemaPythonRepr().result_definition(schema), 'list[B]' ) self.assertEqual( - schema_lib.SchemaPythonRepr().repr(schema), + schema_lib._SchemaPythonRepr().repr(schema), inspect.cleandoc(''' list[B] @@ -633,7 +633,7 @@ def foo_value(self) -> int: '''), ) self.assertEqual( - schema_lib.SchemaPythonRepr().repr( + schema_lib._SchemaPythonRepr().repr( schema, include_result_definition=False, markdown=False, @@ -669,7 +669,7 @@ class SchemaJsonReprTest(unittest.TestCase): def test_repr(self): schema = schema_lib.Schema([{'x': Itinerary}]) self.assertEqual( - schema_lib.SchemaJsonRepr().repr(schema), + schema_lib._SchemaJsonRepr().repr(schema), ( '{"result": [{"x": {"_type": "Itinerary", "day":' ' int(min=1), "type": "daytime" | "nighttime", "activities":' @@ -692,21 +692,21 @@ class A(pg.Object): y: str | None self.assertEqual( - schema_lib.ValuePythonRepr().repr(1, schema_lib.Schema(int)), + schema_lib._ValuePythonRepr().repr(1, schema_lib.Schema(int)), '```python\n1\n```' ) self.assertEqual( - schema_lib.ValuePythonRepr().repr(lf.Template('hi, {{a}}', a='foo')), + schema_lib._ValuePythonRepr().repr(lf.Template('hi, {{a}}', a='foo')), 'hi, foo' ) self.assertEqual( - schema_lib.ValuePythonRepr().repr( + schema_lib._ValuePythonRepr().repr( A([Foo(1), Foo(2)], 'bar'), schema_lib.Schema(A), markdown=False, ), "A(foo=[Foo(x=1), Foo(x=2)], y='bar')", ) self.assertEqual( - schema_lib.ValuePythonRepr().repr(A), + schema_lib._ValuePythonRepr().repr(A), inspect.cleandoc(""" ```python class Foo(Object): @@ -729,7 +729,7 @@ class A(pg.Object): y: str | None self.assertEqual( - schema_lib.ValuePythonRepr().parse( + schema_lib._ValuePythonRepr().parse( "A(foo=[Foo(x=1), Foo(x=2)], 
y='bar')", schema_lib.Schema(A) ), A([Foo(1), Foo(2)], y='bar'), @@ -744,7 +744,7 @@ class A(pg.Object): y: str | None self.assertEqual( - schema_lib.ValuePythonRepr().parse( + schema_lib._ValuePythonRepr().parse( "A(foo=[Foo(x=1), Foo(x=2)], y='bar'", schema_lib.Schema(A), autofix=1, @@ -760,7 +760,7 @@ class A(pg.Object): def test_parse_class_def(self): self.assertTrue( inspect.isclass( - schema_lib.ValuePythonRepr().parse( + schema_lib._ValuePythonRepr().parse( """ class A: x: Dict[str, Any] @@ -775,10 +775,10 @@ class A: class ValueJsonReprTest(unittest.TestCase): def test_repr(self): - self.assertEqual(schema_lib.ValueJsonRepr().repr(1), '{"result": 1}') + self.assertEqual(schema_lib._ValueJsonRepr().repr(1), '{"result": 1}') def assert_parse(self, inputs, output) -> None: - self.assertEqual(schema_lib.ValueJsonRepr().parse(inputs), output) + self.assertEqual(schema_lib._ValueJsonRepr().parse(inputs), output) def test_parse_basics(self): self.assert_parse('{"result": 1}', 1) @@ -797,13 +797,13 @@ def test_parse_basics(self): with self.assertRaisesRegex( schema_lib.JsonError, 'JSONDecodeError' ): - schema_lib.ValueJsonRepr().parse('{"abc", 1}') + schema_lib._ValueJsonRepr().parse('{"abc", 1}') with self.assertRaisesRegex( schema_lib.JsonError, 'The root node of the JSON must be a dict with key `result`' ): - schema_lib.ValueJsonRepr().parse('{"abc": 1}') + schema_lib._ValueJsonRepr().parse('{"abc": 1}') def test_parse_with_surrounding_texts(self): self.assert_parse('The answer is {"result": 1}.', 1) @@ -823,30 +823,30 @@ def test_parse_with_malformated_json(self): with self.assertRaisesRegex( schema_lib.JsonError, 'No JSON dict in the output' ): - schema_lib.ValueJsonRepr().parse('The answer is 1.') + schema_lib._ValueJsonRepr().parse('The answer is 1.') with self.assertRaisesRegex( schema_lib.JsonError, 'Malformated JSON: missing .* closing curly braces' ): - schema_lib.ValueJsonRepr().parse('{"result": 1') + schema_lib._ValueJsonRepr().parse('{"result": 1') 
class ProtocolTest(unittest.TestCase): def test_schema_repr(self): self.assertIsInstance( - schema_lib.schema_repr('json'), schema_lib.SchemaJsonRepr) + schema_lib.schema_repr('json'), schema_lib._SchemaJsonRepr) self.assertIsInstance( - schema_lib.schema_repr('python'), schema_lib.SchemaPythonRepr) + schema_lib.schema_repr('python'), schema_lib._SchemaPythonRepr) with self.assertRaisesRegex(ValueError, 'Unsupported protocol'): schema_lib.schema_repr('text') def test_value_repr(self): self.assertIsInstance( - schema_lib.value_repr('json'), schema_lib.ValueJsonRepr) + schema_lib.value_repr('json'), schema_lib._ValueJsonRepr) self.assertIsInstance( - schema_lib.value_repr('python'), schema_lib.ValuePythonRepr) + schema_lib.value_repr('python'), schema_lib._ValuePythonRepr) with self.assertRaisesRegex(ValueError, 'Unsupported protocol'): schema_lib.value_repr('text') diff --git a/langfun/core/structured/scoring.py b/langfun/core/structured/scoring.py index d7bd9528..f83ebc0b 100644 --- a/langfun/core/structured/scoring.py +++ b/langfun/core/structured/scoring.py @@ -17,7 +17,7 @@ import langfun.core as lf from langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import query as query_lib from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -101,7 +101,7 @@ class Answer(pg.Object): prompts = [] for p in prompt: prompts.append( - prompting.query_prompt( + query_lib.query_prompt( p, schema, examples=examples, @@ -111,7 +111,7 @@ class Answer(pg.Object): ) input_message = prompts else: - input_message = prompting.query_prompt( + input_message = query_lib.query_prompt( prompt, schema, examples=examples, diff --git a/langfun/core/structured/tokenization.py b/langfun/core/structured/tokenization.py index 14c12486..516e2a33 100644 --- a/langfun/core/structured/tokenization.py +++ b/langfun/core/structured/tokenization.py @@ -17,7 +17,7 @@ import langfun.core as lf from 
langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import query as query_lib from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -48,7 +48,7 @@ def tokenize( Returns: A list of (text, token_id) tuples. """ - input_message = prompting.query_prompt( + input_message = query_lib.query_prompt( prompt, schema, examples=examples,