From 51cb7d554a5e927cceb66c9741036fdd38cebb11 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Mon, 16 Dec 2024 12:37:45 -0800 Subject: [PATCH] Clean up lf.structured. 1) Rename prompting.py to querying.py as it's easier for users to locate the source code of `lf.query`. 2) Make helper classes without backward compatibility guarantee private. PiperOrigin-RevId: 706800418 --- docs/notebooks/langfun101.ipynb | 2 +- langfun/core/coding/python/correction.py | 4 +- langfun/core/structured/__init__.py | 29 ++--- langfun/core/structured/completion.py | 4 +- langfun/core/structured/completion_test.py | 8 +- langfun/core/structured/description.py | 4 +- langfun/core/structured/description_test.py | 6 +- .../core/structured/function_generation.py | 6 +- langfun/core/structured/parsing.py | 18 ++-- langfun/core/structured/parsing_test.py | 16 +-- .../structured/{prompting.py => querying.py} | 18 ++-- .../{prompting_test.py => querying_test.py} | 102 +++++++++--------- langfun/core/structured/schema.py | 101 ++++++++--------- langfun/core/structured/scoring.py | 6 +- langfun/core/structured/tokenization.py | 4 +- 15 files changed, 157 insertions(+), 171 deletions(-) rename langfun/core/structured/{prompting.py => querying.py} (97%) rename langfun/core/structured/{prompting_test.py => querying_test.py} (91%) diff --git a/docs/notebooks/langfun101.ipynb b/docs/notebooks/langfun101.ipynb index 47e487c..ddeb2d1 100644 --- a/docs/notebooks/langfun101.ipynb +++ b/docs/notebooks/langfun101.ipynb @@ -10,7 +10,7 @@ "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/langfun/blob/main/docs/notebooks/langfun101.ipynb)\n", "\n", - "Effective programming of Large Language Models (LLMs) demands a seamless integration of natural language text with structured data. 
Langfun, leveraging [PyGlove](https://github.com/google/pyglove)'s symbolic objects, provides a simple yet powerful interface for mapping between Python objects with the assistance of LLMs. The input/output objects may include natural language texts (in string form), structured data (objects of a specific class), modalities (such as images), and more. The unified API for accomplishing all conceivable mappings is [`lf.query`](https://github.com/google/langfun/blob/145f4a48603d2e6e19f9025e032ac3c86dfd3e35/langfun/core/structured/prompting.py#L103).\n" + "Effective programming of Large Language Models (LLMs) demands a seamless integration of natural language text with structured data. Langfun, leveraging [PyGlove](https://github.com/google/pyglove)'s symbolic objects, provides a simple yet powerful interface for mapping between Python objects with the assistance of LLMs. The input/output objects may include natural language texts (in string form), structured data (objects of a specific class), modalities (such as images), and more. The unified API for accomplishing all conceivable mappings is [`lf.query`](https://github.com/google/langfun/blob/145f4a48603d2e6e19f9025e032ac3c86dfd3e35/langfun/core/structured/querying.py#L103).\n" ] }, { diff --git a/langfun/core/coding/python/correction.py b/langfun/core/coding/python/correction.py index d2c1d5d..fcb7289 100644 --- a/langfun/core/coding/python/correction.py +++ b/langfun/core/coding/python/correction.py @@ -76,7 +76,7 @@ def run_with_correction( # Delay import at runtime to avoid circular depenency. # pylint: disable=g-import-not-at-top # pytype: disable=import-error - from langfun.core.structured import prompting + from langfun.core.structured import querying # pytype: enable=import-error # pylint: enable=g-import-not-at-top @@ -119,7 +119,7 @@ def result_and_error(code: str) -> tuple[Any, str | None]: # structure. try: # Disable autofix for code correction to avoid recursion. 
- correction = prompting.query( + correction = querying.query( CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0 ) except errors.CodeError: diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index 685b067..a32c207 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -36,12 +36,6 @@ from langfun.core.structured.schema import annotation from langfun.core.structured.schema import structure_from_python -from langfun.core.structured.schema import SchemaRepr -from langfun.core.structured.schema import SchemaJsonRepr -from langfun.core.structured.schema import SchemaPythonRepr -from langfun.core.structured.schema import ValueRepr -from langfun.core.structured.schema import ValueJsonRepr -from langfun.core.structured.schema import ValuePythonRepr from langfun.core.structured.schema import schema_repr from langfun.core.structured.schema import source_form from langfun.core.structured.schema import value_repr @@ -56,26 +50,17 @@ from langfun.core.structured.mapping import MappingError from langfun.core.structured.mapping import MappingExample -from langfun.core.structured.parsing import ParseStructure -from langfun.core.structured.parsing import ParseStructureJson -from langfun.core.structured.parsing import ParseStructurePython from langfun.core.structured.parsing import parse from langfun.core.structured.parsing import call -from langfun.core.structured.prompting import QueryStructure -from langfun.core.structured.prompting import QueryStructureJson -from langfun.core.structured.prompting import QueryStructurePython -from langfun.core.structured.prompting import query -from langfun.core.structured.prompting import query_prompt -from langfun.core.structured.prompting import query_output -from langfun.core.structured.prompting import query_reward -from langfun.core.structured.prompting import QueryInvocation -from langfun.core.structured.prompting import track_queries - -from 
langfun.core.structured.description import DescribeStructure -from langfun.core.structured.description import describe +from langfun.core.structured.querying import track_queries +from langfun.core.structured.querying import QueryInvocation +from langfun.core.structured.querying import query +from langfun.core.structured.querying import query_prompt +from langfun.core.structured.querying import query_output +from langfun.core.structured.querying import query_reward -from langfun.core.structured.completion import CompleteStructure +from langfun.core.structured.description import describe from langfun.core.structured.completion import complete from langfun.core.structured.scoring import score diff --git a/langfun/core/structured/completion.py b/langfun/core/structured/completion.py index 741d767..dda8f15 100644 --- a/langfun/core/structured/completion.py +++ b/langfun/core/structured/completion.py @@ -21,7 +21,7 @@ import pyglove as pg -class CompleteStructure(mapping.Mapping): +class _CompleteStructure(mapping.Mapping): """Complete structure by filling the missing fields.""" input: Annotated[ @@ -241,7 +241,7 @@ class Flight(pg.Object): Returns: The result based on the schema. 
""" - t = CompleteStructure( + t = _CompleteStructure( input=schema_lib.mark_missing(input_value), default=default, examples=examples, diff --git a/langfun/core/structured/completion_test.py b/langfun/core/structured/completion_test.py index 26a44f4..a93e644 100644 --- a/langfun/core/structured/completion_test.py +++ b/langfun/core/structured/completion_test.py @@ -46,7 +46,7 @@ class TripPlan(pg.Object): class CompleteStructureTest(unittest.TestCase): def test_render_no_examples(self): - l = completion.CompleteStructure() + l = completion._CompleteStructure() input_value = schema_lib.mark_missing( TripPlan.partial( place='San Francisco', @@ -120,7 +120,7 @@ class Activity: ) def test_render_no_class_definitions(self): - l = completion.CompleteStructure() + l = completion._CompleteStructure() input_value = schema_lib.mark_missing( TripPlan.partial( place='San Francisco', @@ -200,7 +200,7 @@ def test_render_no_class_definitions(self): ) def test_render_with_examples(self): - l = completion.CompleteStructure() + l = completion._CompleteStructure() input_value = schema_lib.mark_missing( TripPlan.partial( place='San Francisco', @@ -411,7 +411,7 @@ class Animal(pg.Object): modalities.Image.from_bytes(b'image_of_elephant'), ) ) - l = completion.CompleteStructure( + l = completion._CompleteStructure( input=input_value, examples=[ mapping.MappingExample( diff --git a/langfun/core/structured/description.py b/langfun/core/structured/description.py index 536fdc8..5d8d297 100644 --- a/langfun/core/structured/description.py +++ b/langfun/core/structured/description.py @@ -22,7 +22,7 @@ @pg.use_init_args(['examples']) -class DescribeStructure(mapping.Mapping): +class _DescribeStructure(mapping.Mapping): """Describe a structured value in natural language.""" input_title = 'PYTHON_OBJECT' @@ -106,7 +106,7 @@ class Flight(pg.Object): Returns: The parsed result based on the schema. 
""" - return DescribeStructure( + return _DescribeStructure( input=value, context=context, examples=examples or default_describe_examples(), diff --git a/langfun/core/structured/description_test.py b/langfun/core/structured/description_test.py index dba8dd2..4c74f76 100644 --- a/langfun/core/structured/description_test.py +++ b/langfun/core/structured/description_test.py @@ -36,7 +36,7 @@ class Itinerary(pg.Object): class DescribeStructureTest(unittest.TestCase): def test_render(self): - l = description_lib.DescribeStructure( + l = description_lib._DescribeStructure( input=Itinerary( day=1, type='daytime', @@ -137,7 +137,7 @@ def test_render_no_examples(self): ], hotel=None, ) - l = description_lib.DescribeStructure( + l = description_lib._DescribeStructure( input=value, context='1 day itinerary to SF' ) self.assertEqual( @@ -187,7 +187,7 @@ def test_render_no_context(self): ], hotel=None, ) - l = description_lib.DescribeStructure(input=value) + l = description_lib._DescribeStructure(input=value) self.assertEqual( l.render().text, inspect.cleandoc(""" diff --git a/langfun/core/structured/function_generation.py b/langfun/core/structured/function_generation.py index ef54380..92fdc1d 100644 --- a/langfun/core/structured/function_generation.py +++ b/langfun/core/structured/function_generation.py @@ -21,7 +21,7 @@ from langfun.core import language_model from langfun.core import template from langfun.core.coding import python -from langfun.core.structured import prompting +from langfun.core.structured import querying import pyglove as pg @@ -39,7 +39,7 @@ class PythonFunctionSignature(pg.Object): unittest_examples = None for _ in range(num_retries): - r = prompting.query( + r = querying.query( PythonFunctionSignature(signature=signature), list[UnitTest], lm=lm, @@ -145,7 +145,7 @@ def calculate_area_circle(radius: float) -> float: last_error = None for _ in range(num_retries): try: - source_code = prompting.query( + source_code = querying.query( 
PythonFunctionPrompt(signature=signature), lm=lm ) f = python.evaluate(source_code, global_vars=context) diff --git a/langfun/core/structured/parsing.py b/langfun/core/structured/parsing.py index 1b5689a..50b418f 100644 --- a/langfun/core/structured/parsing.py +++ b/langfun/core/structured/parsing.py @@ -16,13 +16,13 @@ import langfun.core as lf from langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import querying from langfun.core.structured import schema as schema_lib import pyglove as pg @lf.use_init_args(['schema', 'default', 'examples']) -class ParseStructure(mapping.Mapping): +class _ParseStructure(mapping.Mapping): """Parse an object out from a natural language text.""" context_title = 'USER_REQUEST' @@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping): ] -class ParseStructureJson(ParseStructure): +class _ParseStructureJson(_ParseStructure): """Parse an object out from a NL text using JSON as the protocol.""" preamble = """ @@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure): output_title = 'JSON' -class ParseStructurePython(ParseStructure): +class _ParseStructurePython(_ParseStructure): """Parse an object out from a NL text using Python as the protocol.""" preamble = """ @@ -87,7 +87,7 @@ def parse( returns_message: bool = False, **kwargs, ) -> Any: - """Parse a natural langugage message based on schema. + """Parse a natural language message based on schema. Examples: @@ -271,7 +271,7 @@ def call( return lm_output if returns_message else lm_output.text # Call `parsing_lm` for structured parsing. 
- parsing_message = prompting.query( + parsing_message = querying.query( lm_output.text, schema, examples=parsing_examples, @@ -293,11 +293,11 @@ def call( def _parse_structure_cls( protocol: schema_lib.SchemaProtocol, -) -> Type[ParseStructure]: +) -> Type[_ParseStructure]: if protocol == 'json': - return ParseStructureJson + return _ParseStructureJson elif protocol == 'python': - return ParseStructurePython + return _ParseStructurePython else: raise ValueError(f'Unknown protocol: {protocol!r}.') diff --git a/langfun/core/structured/parsing_test.py b/langfun/core/structured/parsing_test.py index 9466702..802d0bd 100644 --- a/langfun/core/structured/parsing_test.py +++ b/langfun/core/structured/parsing_test.py @@ -37,7 +37,7 @@ class Itinerary(pg.Object): class ParseStructurePythonTest(unittest.TestCase): def test_render_no_examples(self): - l = parsing.ParseStructurePython(int) + l = parsing._ParseStructurePython(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( l.render(input=m, context='Compute 12 / 6 + 2.').text, @@ -62,7 +62,7 @@ def test_render_no_examples(self): ) def test_render_no_context(self): - l = parsing.ParseStructurePython(int) + l = parsing._ParseStructurePython(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( @@ -85,7 +85,7 @@ def test_render_no_context(self): ) def test_render(self): - l = parsing.ParseStructurePython( + l = parsing._ParseStructurePython( int, examples=[ mapping.MappingExample( @@ -212,7 +212,7 @@ def test_invocation(self): ), override_attrs=True, ): - l = parsing.ParseStructurePython( + l = parsing._ParseStructurePython( [Itinerary], examples=[ mapping.MappingExample( @@ -295,7 +295,7 @@ def test_parse(self): class ParseStructureJsonTest(unittest.TestCase): def test_render_no_examples(self): - l = parsing.ParseStructureJson(int) + l = parsing._ParseStructureJson(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( l.render(input=m, context='Compute 12 / 6 + 2.').text, @@ -320,7 +320,7 @@ def 
test_render_no_examples(self): ) def test_render_no_context(self): - l = parsing.ParseStructureJson(int) + l = parsing._ParseStructureJson(int) m = lf.AIMessage('12 / 6 + 2 = 4') self.assertEqual( @@ -343,7 +343,7 @@ def test_render_no_context(self): ) def test_render(self): - l = parsing.ParseStructureJson( + l = parsing._ParseStructureJson( int, examples=[ mapping.MappingExample( @@ -504,7 +504,7 @@ def test_invocation(self): override_attrs=True, ): message = lf.LangFunc(lm_input)() - l = parsing.ParseStructureJson( + l = parsing._ParseStructureJson( [Itinerary], examples=[ mapping.MappingExample( diff --git a/langfun/core/structured/prompting.py b/langfun/core/structured/querying.py similarity index 97% rename from langfun/core/structured/prompting.py rename to langfun/core/structured/querying.py index 5005992..b85a632 100644 --- a/langfun/core/structured/prompting.py +++ b/langfun/core/structured/querying.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Symbolic query.""" +"""Query LLM for structured output.""" import contextlib import functools @@ -26,7 +26,7 @@ @lf.use_init_args(['schema', 'default', 'examples']) -class QueryStructure(mapping.Mapping): +class _QueryStructure(mapping.Mapping): """Query an object out from a natural language text.""" context_title = 'CONTEXT' @@ -38,7 +38,7 @@ class QueryStructure(mapping.Mapping): ] -class QueryStructureJson(QueryStructure): +class _QueryStructureJson(_QueryStructure): """Query a structured value using JSON as the protocol.""" preamble = """ @@ -52,10 +52,10 @@ class QueryStructureJson(QueryStructure): 1 + 1 = {{ schema_title }}: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}} {{ output_title}}: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}} """ protocol = 'json' @@ -63,7 +63,7 @@ class QueryStructureJson(QueryStructure): output_title = 'JSON' -class QueryStructurePython(QueryStructure): +class _QueryStructurePython(_QueryStructure): """Query a structured value using Python as the protocol.""" preamble = """ @@ -94,11 +94,11 @@ class Answer: def _query_structure_cls( protocol: schema_lib.SchemaProtocol, -) -> Type[QueryStructure]: +) -> Type[_QueryStructure]: if protocol == 'json': - return QueryStructureJson + return _QueryStructureJson elif protocol == 'python': - return QueryStructurePython + return _QueryStructurePython else: raise ValueError(f'Unknown protocol: {protocol!r}.') diff --git a/langfun/core/structured/prompting_test.py b/langfun/core/structured/querying_test.py similarity index 91% rename from langfun/core/structured/prompting_test.py rename to langfun/core/structured/querying_test.py index 2e54b1d..ecc7678 100644 --- a/langfun/core/structured/prompting_test.py +++ 
b/langfun/core/structured/querying_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for structured prompting.""" +"""Tests for structured query.""" import inspect import math @@ -23,7 +23,7 @@ from langfun.core.llms import fake from langfun.core.llms.cache import in_memory from langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import querying import pyglove as pg @@ -51,7 +51,7 @@ def assert_render( expected_modalities: int = 0, **kwargs, ): - m = prompting.query( + m = querying.query( prompt, schema=schema, examples=examples, **kwargs, returns_message=True ) @@ -67,14 +67,14 @@ def assert_render( def test_call(self): lm = fake.StaticSequence(['1']) - self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm), 1) + self.assertEqual(querying.query('what is 1 + 0', int, lm=lm), 1) # Testing calling the same `lm` without copy. 
with self.assertRaises(IndexError): - prompting.query('what is 1 + 2', int, lm=lm) + querying.query('what is 1 + 2', int, lm=lm) self.assertEqual( - prompting.query( + querying.query( 'what is 1 + 0', int, lm=lm.clone(), returns_message=True ), lf.AIMessage( @@ -88,17 +88,17 @@ def test_call(self): ), ) self.assertEqual( - prompting.query( + querying.query( lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone() ), 1, ) self.assertEqual( - prompting.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()), + querying.query('what is {{x}} + {{y}}', int, x=1, y=0, lm=lm.clone()), 1, ) self.assertEqual( - prompting.query( + querying.query( 'what is {{x}} + {{y}}', x=1, y=0, @@ -107,7 +107,7 @@ def test_call(self): 'The answer is one.', ) self.assertEqual( - prompting.query( + querying.query( Activity.partial(), lm=fake.StaticResponse('Activity(description="hello")'), ), @@ -329,11 +329,11 @@ def test_structure_with_modality_and_examples_to_structure_render(self): def test_bad_protocol(self): with self.assertRaisesRegex(ValueError, 'Unknown protocol'): - prompting.query('what is 1 + 1', int, protocol='text') + querying.query('what is 1 + 1', int, protocol='text') def test_query_prompt(self): self.assertEqual( - prompting.query_prompt('what is this?', int), + querying.query_prompt('what is this?', int), inspect.cleandoc(""" Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE. 
@@ -368,14 +368,14 @@ class Answer: def test_query_prompt_with_metadata(self): self.assertIn( 'x', - prompting.query_prompt( + querying.query_prompt( 'what is this?', metadata_x=1 ).metadata ) self.assertIn( 'x', - prompting.query_prompt( + querying.query_prompt( 'what is this?', int, metadata_x=1 @@ -383,7 +383,7 @@ def test_query_prompt_with_metadata(self): ) def test_query_prompt_with_unrooted_template(self): - output = prompting.query_prompt( + output = querying.query_prompt( pg.Dict( input=lf.Template( 'what is {{image}}', @@ -395,7 +395,7 @@ def test_query_prompt_with_unrooted_template(self): def test_query_output(self): self.assertEqual( - prompting.query_output( + querying.query_output( lf.AIMessage('1'), int, ), @@ -414,7 +414,7 @@ def __reward__(self, inputs: lf.Template) -> None: # Case 1: Reward function based on input and output. self.assertEqual( - prompting.query_reward( + querying.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), schema=Answer, @@ -425,7 +425,7 @@ def __reward__(self, inputs: lf.Template) -> None: 1.0 ) self.assertEqual( - prompting.query_reward( + querying.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=2, y=3), output=Answer(final_answer=2), @@ -445,7 +445,7 @@ def __reward__(self, inputs: lf.Template, expected_output: 'Answer2'): ) self.assertEqual( - prompting.query_reward( + querying.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer2(final_answer=2), @@ -470,7 +470,7 @@ def __reward__(self, ) * metadata['weight'] self.assertEqual( - prompting.query_reward( + querying.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer3(final_answer=2), @@ -486,7 +486,7 @@ class Answer4(pg.Object): final_answer: int self.assertIsNone( - prompting.query_reward( + querying.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer4(final_answer=2), @@ 
-497,7 +497,7 @@ class Answer4(pg.Object): # Case 5: Not a structured output. self.assertIsNone( - prompting.query_reward( + querying.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output='2', @@ -516,7 +516,7 @@ def __reward__(self): with self.assertRaisesRegex( TypeError, '.*Answer5.__reward__` should have signature' ): - prompting.query_reward( + querying.query_reward( mapping.MappingExample( input=lf.Template('{{x}} + {{y}}', x=1, y=1), output=Answer5(final_answer=2), @@ -528,7 +528,7 @@ def __reward__(self): class QueryStructurePythonTest(unittest.TestCase): def test_render_no_examples(self): - l = prompting.QueryStructurePython( + l = querying._QueryStructurePython( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int ) self.assertEqual( @@ -565,7 +565,7 @@ class Answer: ) def test_render(self): - l = prompting.QueryStructurePython( + l = querying._QueryStructurePython( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int, examples=[ @@ -675,7 +675,7 @@ def test_invocation(self): ), override_attrs=True, ): - l = prompting.QueryStructurePython( + l = querying._QueryStructurePython( input=lm_input, schema=[Itinerary], examples=[ @@ -712,7 +712,7 @@ def test_bad_response(self): mapping.MappingError, 'name .* is not defined', ): - prompting.query('Compute 1 + 2', int) + querying.query('Compute 1 + 2', int) def test_autofix(self): lm = fake.StaticSequence([ @@ -723,7 +723,7 @@ def test_autofix(self): ) """), ]) - self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1) + self.assertEqual(querying.query('what is 1 + 0', int, lm=lm, autofix=3), 1) def test_response_postprocess(self): with lf.context( @@ -731,12 +731,12 @@ def test_response_postprocess(self): override_attrs=True, ): self.assertEqual( - prompting.query( + querying.query( 'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]), '3' ) self.assertEqual( - prompting.query( + querying.query( 'Compute 1 + 2', int, 
response_postprocess=lambda x: x.split('\n')[1]), 3 @@ -746,7 +746,7 @@ def test_response_postprocess(self): class QueryStructureJsonTest(unittest.TestCase): def test_render_no_examples(self): - l = prompting.QueryStructureJson( + l = querying._QueryStructureJson( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int ) self.assertEqual( @@ -762,10 +762,10 @@ def test_render_no_examples(self): 1 + 1 = SCHEMA: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}} JSON: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}} INPUT_OBJECT: Compute 12 / 6 + 2. @@ -778,7 +778,7 @@ def test_render_no_examples(self): ) def test_render(self): - l = prompting.QueryStructureJson( + l = querying._QueryStructureJson( input=lf.AIMessage('Compute 12 / 6 + 2.'), schema=int, examples=[ @@ -799,10 +799,10 @@ def test_render(self): 1 + 1 = SCHEMA: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}} JSON: - {"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}} + {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}} INPUT_OBJECT: What is the answer of 1 plus 1? @@ -913,7 +913,7 @@ def test_invocation(self): ), override_attrs=True, ): - l = prompting.QueryStructureJson( + l = querying._QueryStructureJson( input=lm_input, schema=[Itinerary], examples=[ @@ -951,14 +951,14 @@ def test_bad_transform(self): mapping.MappingError, 'No JSON dict in the output', ): - prompting.query('Compute 1 + 2', int, protocol='json', cache_seed=1) + querying.query('Compute 1 + 2', int, protocol='json', cache_seed=1) # Make sure bad mapping does not impact cache. 
self.assertEqual(len(cache), 0) def test_query(self): lm = fake.StaticSequence(['{"result": 1}']) self.assertEqual( - prompting.query('what is 1 + 0', int, lm=lm, protocol='json'), 1 + querying.query('what is 1 + 0', int, lm=lm, protocol='json'), 1 ) @@ -968,8 +968,8 @@ def test_to_html(self): lm = fake.StaticSequence([ 'Activity(description="hi")', ]) - with prompting.track_queries() as queries: - prompting.query('foo', Activity, lm=lm) + with querying.track_queries() as queries: + querying.query('foo', Activity, lm=lm) self.assertIn('schema', queries[0].to_html_str()) @@ -981,10 +981,10 @@ def test_include_child_scopes(self): 'bar', 'Activity(description="hi")', ]) - with prompting.track_queries() as queries: - prompting.query('foo', lm=lm) - with prompting.track_queries() as child_queries: - prompting.query('give me an activity', Activity, lm=lm) + with querying.track_queries() as queries: + querying.query('foo', lm=lm) + with querying.track_queries() as child_queries: + querying.query('give me an activity', Activity, lm=lm) self.assertEqual(len(queries), 2) self.assertTrue(pg.eq(queries[0].input, lf.Template('foo'))) @@ -1008,10 +1008,10 @@ def test_exclude_child_scopes(self): 'bar', 'Activity(description="hi")', ]) - with prompting.track_queries(include_child_scopes=False) as queries: - prompting.query('foo', lm=lm) - with prompting.track_queries(include_child_scopes=False) as child_queries: - prompting.query('give me an activity', Activity, lm=lm) + with querying.track_queries(include_child_scopes=False) as queries: + querying.query('foo', lm=lm) + with querying.track_queries(include_child_scopes=False) as child_queries: + querying.query('give me an activity', Activity, lm=lm) self.assertEqual(len(queries), 1) self.assertTrue(pg.eq(queries[0].input, lf.Template('foo'))) @@ -1030,13 +1030,13 @@ def test_exclude_child_scopes(self): def test_concurrent_map(self): def make_query(prompt): - _ = prompting.query(prompt, lm=lm) + _ = querying.query(prompt, lm=lm) lm 
= fake.StaticSequence([ 'foo', 'bar', ]) - with prompting.track_queries() as queries: + with querying.track_queries() as queries: list(lf.concurrent_map(make_query, ['a', 'b'])) self.assertEqual(len(queries), 2) diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index ad30c23..c1677d9 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -388,9 +388,9 @@ def result_definition(self, schema: Schema) -> str: return annotation(schema.spec) -def source_form(value, markdown: bool = False) -> str: +def source_form(value, compact: bool = True, markdown: bool = False) -> str: """Returns the source code form of an object.""" - return ValuePythonRepr().repr(value, markdown=markdown) + return ValuePythonRepr().repr(value, compact=compact, markdown=markdown) def class_definitions( @@ -789,7 +789,7 @@ def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any: """Parse a JSON string into a structured object.""" del schema try: - text = self.cleanup_json(text) + text = cleanup_json(text) v = pg.from_json_str(text, **kwargs) except Exception as e: raise JsonError(text, e) # pylint: disable=raise-missing-from @@ -801,55 +801,56 @@ def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any: )) return v['result'] - def cleanup_json(self, json_str: str) -> str: - """Clean up the LM responded JSON string.""" - # Treatments: - # 1. Extract the JSON string with a top-level dict from the response. - # This prevents the leading and trailing texts in the response to - # be counted as part of the JSON. - # 2. Escape new lines in JSON values. 
- - curly_brackets = 0 - under_json = False - under_str = False - str_begin = -1 - - cleaned = io.StringIO() - for i, c in enumerate(json_str): - if c == '{' and not under_str: - cleaned.write(c) - curly_brackets += 1 - under_json = True - continue - elif not under_json: - continue - if c == '}' and not under_str: - cleaned.write(c) - curly_brackets -= 1 - if curly_brackets == 0: - break - elif c == '"' and json_str[i - 1] != '\\': - under_str = not under_str - if under_str: - str_begin = i - else: - assert str_begin > 0 - str_value = json_str[str_begin : i + 1].replace('\n', '\\n') - cleaned.write(str_value) - str_begin = -1 - elif not under_str: - cleaned.write(c) - - if not under_json: - raise ValueError(f'No JSON dict in the output: {json_str}') - - if curly_brackets > 0: - raise ValueError( - f'Malformated JSON: missing {curly_brackets} closing curly braces.' - ) +def cleanup_json(json_str: str) -> str: + """Clean up the LM responded JSON string.""" + # Treatments: + # 1. Extract the JSON string with a top-level dict from the response. + # This prevents the leading and trailing texts in the response to + # be counted as part of the JSON. + # 2. Escape new lines in JSON values. 
+ + curly_brackets = 0 + under_json = False + under_str = False + str_begin = -1 + + cleaned = io.StringIO() + for i, c in enumerate(json_str): + if c == '{' and not under_str: + cleaned.write(c) + curly_brackets += 1 + under_json = True + continue + elif not under_json: + continue + + if c == '}' and not under_str: + cleaned.write(c) + curly_brackets -= 1 + if curly_brackets == 0: + break + elif c == '"' and json_str[i - 1] != '\\': + under_str = not under_str + if under_str: + str_begin = i + else: + assert str_begin > 0 + str_value = json_str[str_begin : i + 1].replace('\n', '\\n') + cleaned.write(str_value) + str_begin = -1 + elif not under_str: + cleaned.write(c) + + if not under_json: + raise ValueError(f'No JSON dict in the output: {json_str}') + + if curly_brackets > 0: + raise ValueError( + f'Malformated JSON: missing {curly_brackets} closing curly braces.' + ) - return cleaned.getvalue() + return cleaned.getvalue() def schema_repr(protocol: SchemaProtocol) -> SchemaRepr: diff --git a/langfun/core/structured/scoring.py b/langfun/core/structured/scoring.py index d7bd952..fe46907 100644 --- a/langfun/core/structured/scoring.py +++ b/langfun/core/structured/scoring.py @@ -17,7 +17,7 @@ import langfun.core as lf from langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import querying from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -101,7 +101,7 @@ class Answer(pg.Object): prompts = [] for p in prompt: prompts.append( - prompting.query_prompt( + querying.query_prompt( p, schema, examples=examples, @@ -111,7 +111,7 @@ class Answer(pg.Object): ) input_message = prompts else: - input_message = prompting.query_prompt( + input_message = querying.query_prompt( prompt, schema, examples=examples, diff --git a/langfun/core/structured/tokenization.py b/langfun/core/structured/tokenization.py index 14c1248..5407f8d 100644 --- a/langfun/core/structured/tokenization.py +++ 
b/langfun/core/structured/tokenization.py @@ -17,7 +17,7 @@ import langfun.core as lf from langfun.core.structured import mapping -from langfun.core.structured import prompting +from langfun.core.structured import querying from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -48,7 +48,7 @@ def tokenize( Returns: A list of (text, token_id) tuples. """ - input_message = prompting.query_prompt( + input_message = querying.query_prompt( prompt, schema, examples=examples,