From 576bba2bfb72540be8ef35f5f39bb0df69f23c54 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Tue, 26 Sep 2023 12:24:29 -0700 Subject: [PATCH] Restructure `lf.structured` files. The goal is to make this directory more extensible and readable. PiperOrigin-RevId: 568616171 --- langfun/core/structured/__init__.py | 35 +- .../{structure2structure.py => completion.py} | 117 +---- ...e2structure_test.py => completion_test.py} | 36 +- .../{structure2nl.py => description.py} | 65 +-- ...ructure2nl_test.py => description_test.py} | 10 +- langfun/core/structured/mapping.py | 290 ++++++++++++- langfun/core/structured/mapping_test.py | 28 ++ .../{nl2structure.py => parsing.py} | 274 +----------- .../{nl2structure_test.py => parsing_test.py} | 398 ++---------------- langfun/core/structured/query.py | 183 ++++++++ langfun/core/structured/query_test.py | 390 +++++++++++++++++ 11 files changed, 960 insertions(+), 866 deletions(-) rename langfun/core/structured/{structure2structure.py => completion.py} (52%) rename langfun/core/structured/{structure2structure_test.py => completion_test.py} (92%) rename langfun/core/structured/{structure2nl.py => description.py} (74%) rename langfun/core/structured/{structure2nl_test.py => description_test.py} (96%) rename langfun/core/structured/{nl2structure.py => parsing.py} (55%) rename langfun/core/structured/{nl2structure_test.py => parsing_test.py} (57%) create mode 100644 langfun/core/structured/query.py create mode 100644 langfun/core/structured/query_test.py diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index e4b7c649..d64cd0e6 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -45,23 +45,28 @@ from langfun.core.structured.mapping import MappingExample from langfun.core.structured.mapping import MappingError -from langfun.core.structured.nl2structure import NaturalLanguageToStructure -from langfun.core.structured.nl2structure import ParseStructure -from langfun.core.structured.nl2structure import ParseStructureJson -from langfun.core.structured.nl2structure import ParseStructurePython -from langfun.core.structured.nl2structure import QueryStructure -from langfun.core.structured.nl2structure import QueryStructureJson -from langfun.core.structured.nl2structure import QueryStructurePython -from langfun.core.structured.nl2structure import parse -from langfun.core.structured.nl2structure import query +# Mappings of between different forms of content. +from langfun.core.structured.mapping import NaturalLanguageToStructure +from langfun.core.structured.mapping import StructureToNaturalLanguage +from langfun.core.structured.mapping import StructureToStructure -from langfun.core.structured.structure2nl import StructureToNaturalLanguage -from langfun.core.structured.structure2nl import DescribeStructure -from langfun.core.structured.structure2nl import describe +from langfun.core.structured.parsing import ParseStructure +from langfun.core.structured.parsing import ParseStructureJson +from langfun.core.structured.parsing import ParseStructurePython +from langfun.core.structured.parsing import parse -from langfun.core.structured.structure2structure import StructureToStructure -from langfun.core.structured.structure2structure import CompleteStructure -from langfun.core.structured.structure2structure import complete +import langfun.core.structured.query as query_lib + +from langfun.core.structured.query import QueryStructure +from langfun.core.structured.query import QueryStructureJson +from langfun.core.structured.query import QueryStructurePython +from langfun.core.structured.query import query + +from langfun.core.structured.description import DescribeStructure +from langfun.core.structured.description import describe + +from langfun.core.structured.completion import CompleteStructure +from langfun.core.structured.completion import complete # pylint: enable=g-importing-member diff --git a/langfun/core/structured/structure2structure.py b/langfun/core/structured/completion.py similarity index 52% rename from langfun/core/structured/structure2structure.py rename to langfun/core/structured/completion.py index 00ff5142..900f5ec2 100644 --- a/langfun/core/structured/structure2structure.py +++ b/langfun/core/structured/completion.py @@ -13,7 +13,7 @@ # limitations under the License. """Structure-to-structure mappings.""" -from typing import Annotated, Any, Literal, Type +from typing import Any, Literal import langfun.core as lf from langfun.core.structured import mapping @@ -21,118 +21,7 @@ import pyglove as pg -class Pair(pg.Object): - """Value pair used for expressing structure-to-structure mapping.""" - - left: pg.typing.Annotated[ - pg.typing.Any(transform=schema_lib.mark_missing), 'The left-side value.' - ] - right: pg.typing.Annotated[ - pg.typing.Any(transform=schema_lib.mark_missing), 'The right-side value.' - ] - - -class StructureToStructure(mapping.Mapping): - """Base class for structure-to-structure mapping. - - {{ preamble }} - - {% if examples -%} - {% for example in examples -%} - {{ input_value_title }}: - {{ value_str(example.value.left) | indent(2, True) }} - - {%- if missing_type_dependencies(example.value) %} - - {{ type_definitions_title }}: - {{ type_definitions_str(example.value) | indent(2, True) }} - {%- endif %} - - {{ output_value_title }}: - {{ value_str(example.value.right) | indent(2, True) }} - - {% endfor %} - {% endif -%} - {{ input_value_title }}: - {{ value_str(input_value) | indent(2, True) }} - {%- if missing_type_dependencies(input_value) %} - - {{ type_definitions_title }}: - {{ type_definitions_str(input_value) | indent(2, True) }} - {%- endif %} - - {{ output_value_title }}: - """ - - default: Annotated[ - Any, - ( - 'The default value to use if mapping failed. ' - 'If unspecified, error will be raisen.' - ), - ] = lf.message_transform.RAISE_IF_HAS_ERROR - - preamble: Annotated[ - lf.LangFunc, - 'Preamble used for structure-to-structure mapping.', - ] - - type_definitions_title: Annotated[ - str, 'The section title for type definitions.' - ] = 'CLASS_DEFINITIONS' - - input_value_title: Annotated[str, 'The section title for input value.'] - output_value_title: Annotated[str, 'The section title for output value.'] - - def _on_bound(self): - super()._on_bound() - if self.examples: - for example in self.examples: - if not isinstance(example.value, Pair): - raise ValueError( - 'The value of example must be a `lf.structured.Pair` object. ' - f'Encountered: { example.value }.' - ) - - @property - def input_value(self) -> Any: - return schema_lib.mark_missing(self.message.result) - - def value_str(self, value: Any) -> str: - return schema_lib.value_repr('python').repr(value, compact=False) - - def missing_type_dependencies(self, value: Any) -> list[Type[Any]]: - value_specs = tuple( - [v.value_spec for v in schema_lib.Missing.find_missing(value).values()] - ) - return schema_lib.class_dependencies(value_specs, include_subclasses=True) - - def type_definitions_str(self, value: Any) -> str | None: - return schema_lib.class_definitions( - self.missing_type_dependencies(value), markdown=True - ) - - def _value_context(self): - classes = schema_lib.class_dependencies(self.input_value) - return {cls.__name__: cls for cls in classes} - - def transform_output(self, lm_output: lf.Message) -> lf.Message: - try: - result = schema_lib.value_repr('python').parse( - lm_output.text, additional_context=self._value_context() - ) - except Exception as e: # pylint: disable=broad-exception-caught - if self.default == lf.message_transform.RAISE_IF_HAS_ERROR: - raise mapping.MappingError( - 'Cannot parse message text into structured output. ' - f'Error={e}. Text={lm_output.text!r}.' - ) from e - result = self.default - lm_output.result = result - return lm_output - - -class CompleteStructure(StructureToStructure): +class CompleteStructure(mapping.StructureToStructure): """Complete structure by filling the missing fields.""" preamble = lf.LangFunc(""" @@ -173,7 +62,7 @@ class _Country(pg.Object): def _default_complete_examples() -> list[mapping.MappingExample]: return [ mapping.MappingExample( - value=Pair( + value=mapping.Pair( left=_Country.partial(name='United States of America'), right=_Country( name='United States of America', diff --git a/langfun/core/structured/structure2structure_test.py b/langfun/core/structured/completion_test.py similarity index 92% rename from langfun/core/structured/structure2structure_test.py rename to langfun/core/structured/completion_test.py index 0134887d..dfd8f8bc 100644 --- a/langfun/core/structured/structure2structure_test.py +++ b/langfun/core/structured/completion_test.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for langfun.core.structured.structure2structure.""" +"""Tests for langfun.core.structured.completion.""" import inspect import unittest + import langfun.core as lf from langfun.core.llms import fake +from langfun.core.structured import completion from langfun.core.structured import mapping -from langfun.core.structured import schema as schema_lib -from langfun.core.structured import structure2structure import pyglove as pg @@ -39,22 +39,10 @@ class TripPlan(pg.Object): itineraries: list[Itinerary] -class PairTest(unittest.TestCase): - - def test_partial(self): - p = structure2structure.Pair( - TripPlan.partial(place='San Francisco'), - TripPlan.partial(itineraries=[Itinerary.partial(day=1)]), - ) - self.assertEqual(p.left.itineraries, schema_lib.MISSING) - self.assertEqual(p.right.place, schema_lib.MISSING) - self.assertEqual(p.right.itineraries[0].activities, schema_lib.MISSING) - - class CompleteStructureTest(unittest.TestCase): def test_render_no_examples(self): - l = structure2structure.CompleteStructure() + l = completion.CompleteStructure() m = lf.UserMessage( '', result=TripPlan.partial( @@ -115,7 +103,7 @@ class Activity: ) def test_render_no_class_definitions(self): - l = structure2structure.CompleteStructure() + l = completion.CompleteStructure() m = lf.UserMessage( '', result=TripPlan.partial( @@ -182,8 +170,8 @@ def test_render_no_class_definitions(self): ) def test_render_with_examples(self): - l = structure2structure.CompleteStructure( - examples=structure2structure._default_complete_examples() + l = completion.CompleteStructure( + examples=completion._default_complete_examples() ) m = lf.UserMessage( '', @@ -323,7 +311,7 @@ def test_transform(self): ), override_attrs=True, ): - r = structure2structure.complete( + r = completion.complete( TripPlan.partial( place='San Francisco', itineraries=[ @@ -413,9 +401,7 @@ def test_transform(self): def test_bad_init(self): with self.assertRaisesRegex(ValueError, '.*must be.*Pair'): - structure2structure.CompleteStructure( - examples=[mapping.MappingExample(value=1)] - ) + completion.CompleteStructure(examples=[mapping.MappingExample(value=1)]) def test_bad_transform(self): with lf.context( @@ -426,14 +412,14 @@ def test_bad_transform(self): mapping.MappingError, 'Cannot parse message text into structured output', ): - structure2structure.complete(Activity.partial()) + completion.complete(Activity.partial()) def test_default(self): with lf.context( lm=fake.StaticSequence(['Activity(description=1)']), override_attrs=True, ): - self.assertIsNone(structure2structure.complete(Activity.partial(), None)) + self.assertIsNone(completion.complete(Activity.partial(), None)) if __name__ == '__main__': diff --git a/langfun/core/structured/structure2nl.py b/langfun/core/structured/description.py similarity index 74% rename from langfun/core/structured/structure2nl.py rename to langfun/core/structured/description.py index b448ad75..76aa63fa 100644 --- a/langfun/core/structured/structure2nl.py +++ b/langfun/core/structured/description.py @@ -14,76 +14,15 @@ """Structured value to natural language.""" import inspect -from typing import Annotated, Any, Literal +from typing import Any, Literal import langfun.core as lf from langfun.core.structured import mapping -from langfun.core.structured import schema as schema_lib import pyglove as pg -class StructureToNaturalLanguage(mapping.Mapping): - """LangFunc for converting a structured value to natural language. - - {{ preamble }} - - {% if examples -%} - {% for example in examples -%} - {%- if example.nl_context -%} - {{ nl_context_title}}: - {{ example.nl_context | indent(2, True)}} - - {% endif -%} - {{ value_title}}: - {{ value_str(example.value) | indent(2, True) }} - - {{ nl_text_title }}: - {{ example.nl_text | indent(2, True) }} - - {% endfor %} - {% endif -%} - {% if nl_context -%} - {{ nl_context_title }}: - {{ nl_context | indent(2, True)}} - - {% endif -%} - {{ value_title }}: - {{ value_str(value) | indent(2, True) }} - - {{ nl_text_title }}: - """ - - preamble: Annotated[ - lf.LangFunc, 'Preamble used for zeroshot natural language mapping.' - ] - - nl_context_title: Annotated[str, 'The section title for nl_context.'] = ( - 'CONTEXT_FOR_DESCRIPTION' - ) - - nl_text_title: Annotated[str, 'The section title for nl_text.'] = ( - 'NATURAL_LANGUAGE_TEXT' - ) - - value_title: Annotated[str, 'The section title for schema.'] = 'PYTHON_OBJECT' - - @property - def value(self) -> Any: - """Returns the structured input value.""" - return self.message.result - - @property - def nl_context(self) -> str: - """Returns the context information for the description.""" - return self.message.text - - def value_str(self, value: Any) -> str: - return schema_lib.value_repr('python').repr( - value, markdown=False, compact=False) - - @pg.use_init_args(['examples']) -class DescribeStructure(StructureToNaturalLanguage): +class DescribeStructure(mapping.StructureToNaturalLanguage): """Describe a structured value in natural language.""" preamble = """ diff --git a/langfun/core/structured/structure2nl_test.py b/langfun/core/structured/description_test.py similarity index 96% rename from langfun/core/structured/structure2nl_test.py rename to langfun/core/structured/description_test.py index ac3a13c1..57159625 100644 --- a/langfun/core/structured/structure2nl_test.py +++ b/langfun/core/structured/description_test.py @@ -18,8 +18,8 @@ import langfun.core as lf from langfun.core.llms import fake +from langfun.core.structured import description as description_lib from langfun.core.structured import mapping -from langfun.core.structured import structure2nl import pyglove as pg @@ -37,7 +37,7 @@ class Itinerary(pg.Object): class DescribeStructureTest(unittest.TestCase): def test_render(self): - l = structure2nl.DescribeStructure( + l = description_lib.DescribeStructure( examples=[ mapping.MappingExample( nl_context='Compute 1 + 2', @@ -124,7 +124,7 @@ def test_render(self): ) def test_render_no_examples(self): - l = structure2nl.DescribeStructure() + l = description_lib.DescribeStructure() m = lf.UserMessage( '1 day itinerary to SF', result=Itinerary( @@ -174,7 +174,7 @@ def test_render_no_examples(self): ) def test_render_no_context(self): - l = structure2nl.DescribeStructure() + l = description_lib.DescribeStructure() m = lf.UserMessage( '', result=Itinerary( @@ -230,7 +230,7 @@ def test_describe(self): ] ) self.assertEqual( - structure2nl.describe( + description_lib.describe( Itinerary( day=1, type='daytime', diff --git a/langfun/core/structured/mapping.py b/langfun/core/structured/mapping.py index 8df144c7..809ffbab 100644 --- a/langfun/core/structured/mapping.py +++ b/langfun/core/structured/mapping.py @@ -13,8 +13,9 @@ # limitations under the License. """Mapping interfaces.""" +import abc import io -from typing import Annotated +from typing import Annotated, Any, Type import langfun.core as lf from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -135,3 +136,290 @@ class Mapping(lf.LangFunc): list[MappingExample] | None, 'Fewshot examples for improving the quality of mapping.' ] = lf.contextual(default=None) + + +class NaturalLanguageToStructure(Mapping): + """LangFunc for converting natural language text to structured value. + + {{ preamble }} + + {% if examples -%} + {% for example in examples -%} + {%- if example.nl_context -%} + {{ nl_context_title}}: + {{ example.nl_context | indent(2, True)}} + + {% endif -%} + {%- if example.nl_text -%} + {{ nl_text_title }}: + {{ example.nl_text | indent(2, True) }} + + {% endif -%} + {{ schema_title }}: + {{ example.schema_str(protocol) | indent(2, True) }} + + {{ value_title }}: + {{ example.value_str(protocol) | indent(2, True) }} + + {% endfor %} + {% endif -%} + {% if nl_context -%} + {{ nl_context_title }}: + {{ nl_context | indent(2, True)}} + + {% endif -%} + {% if nl_text -%} + {{ nl_text_title }}: + {{ nl_text | indent(2, True) }} + + {% endif -%} + {{ schema_title }}: + {{ schema.schema_str(protocol) | indent(2, True) }} + + {{ value_title }}: + """ + + schema: pg.typing.Annotated[ + # Automatic conversion from annotation to schema. + schema_lib.schema_spec(), + 'A `lf.structured.Schema` that constrains the structured value.', + ] + + default: Annotated[ + Any, + ( + 'The default value to use if parsing failed. ' + 'If unspecified, error will be raisen.' + ), + ] = lf.message_transform.RAISE_IF_HAS_ERROR + + preamble: Annotated[ + lf.LangFunc, + 'Preamble used for natural language-to-structure mapping.', + ] + + nl_context_title: Annotated[str, 'The section title for nl_context.'] = ( + 'USER_REQUEST' + ) + + nl_text_title: Annotated[str, 'The section title for nl_text.'] = ( + 'LM_RESPONSE' + ) + + schema_title: Annotated[str, 'The section title for schema.'] + + value_title: Annotated[str, 'The section title for schema.'] + + protocol: Annotated[ + schema_lib.SchemaProtocol, + 'The protocol for representing the schema and value.', + ] + + @property + @abc.abstractmethod + def nl_context(self) -> str | None: + """Returns the natural language context for obtaining the response. + + Returns: + The natural language context (prompt) for obtaining the response (either + in natural language or directly to structured protocol). If None, + `nl_text` + must be provided. + """ + + @property + @abc.abstractmethod + def nl_text(self) -> str | None: + """Returns the natural language text to map. + + Returns: + The natural language text (in LM response) to map to object. If None, + the LM directly outputs structured protocol instead of natural language. + If None, `nl_context` must be provided. + """ + + def transform_output(self, lm_output: lf.Message) -> lf.Message: + try: + lm_output.result = self.schema.parse( + lm_output.text, protocol=self.protocol + ) + except Exception as e: # pylint: disable=broad-exception-caught + if self.default == lf.message_transform.RAISE_IF_HAS_ERROR: + raise MappingError( + 'Cannot parse message text into structured output. ' + f'Error={e}. Text={lm_output.text!r}.' + ) from e + lm_output.result = self.default + return lm_output + + +class StructureToNaturalLanguage(Mapping): + """LangFunc for converting a structured value to natural language. + + {{ preamble }} + + {% if examples -%} + {% for example in examples -%} + {%- if example.nl_context -%} + {{ nl_context_title}}: + {{ example.nl_context | indent(2, True)}} + + {% endif -%} + {{ value_title}}: + {{ value_str(example.value) | indent(2, True) }} + + {{ nl_text_title }}: + {{ example.nl_text | indent(2, True) }} + + {% endfor %} + {% endif -%} + {% if nl_context -%} + {{ nl_context_title }}: + {{ nl_context | indent(2, True)}} + + {% endif -%} + {{ value_title }}: + {{ value_str(value) | indent(2, True) }} + + {{ nl_text_title }}: + """ + + preamble: Annotated[ + lf.LangFunc, 'Preamble used for zeroshot natural language mapping.' + ] + + nl_context_title: Annotated[str, 'The section title for nl_context.'] = ( + 'CONTEXT_FOR_DESCRIPTION' + ) + + nl_text_title: Annotated[str, 'The section title for nl_text.'] = ( + 'NATURAL_LANGUAGE_TEXT' + ) + + value_title: Annotated[str, 'The section title for schema.'] = 'PYTHON_OBJECT' + + @property + def value(self) -> Any: + """Returns the structured input value.""" + return self.message.result + + @property + def nl_context(self) -> str: + """Returns the context information for the description.""" + return self.message.text + + def value_str(self, value: Any) -> str: + return schema_lib.value_repr('python').repr( + value, markdown=False, compact=False + ) + + +class Pair(pg.Object): + """Value pair used for expressing structure-to-structure mapping.""" + + left: pg.typing.Annotated[ + pg.typing.Any(transform=schema_lib.mark_missing), 'The left-side value.' + ] + right: pg.typing.Annotated[ + pg.typing.Any(transform=schema_lib.mark_missing), 'The right-side value.' + ] + + +class StructureToStructure(Mapping): + """Base class for structure-to-structure mapping. + + {{ preamble }} + + {% if examples -%} + {% for example in examples -%} + {{ input_value_title }}: + {{ value_str(example.value.left) | indent(2, True) }} + + {%- if missing_type_dependencies(example.value) %} + + {{ type_definitions_title }}: + {{ type_definitions_str(example.value) | indent(2, True) }} + {%- endif %} + + {{ output_value_title }}: + {{ value_str(example.value.right) | indent(2, True) }} + + {% endfor %} + {% endif -%} + {{ input_value_title }}: + {{ value_str(input_value) | indent(2, True) }} + {%- if missing_type_dependencies(input_value) %} + + {{ type_definitions_title }}: + {{ type_definitions_str(input_value) | indent(2, True) }} + {%- endif %} + + {{ output_value_title }}: + """ + + default: Annotated[ + Any, + ( + 'The default value to use if mapping failed. ' + 'If unspecified, error will be raisen.' + ), + ] = lf.message_transform.RAISE_IF_HAS_ERROR + + preamble: Annotated[ + lf.LangFunc, + 'Preamble used for structure-to-structure mapping.', + ] + + type_definitions_title: Annotated[ + str, 'The section title for type definitions.' + ] = 'CLASS_DEFINITIONS' + + input_value_title: Annotated[str, 'The section title for input value.'] + output_value_title: Annotated[str, 'The section title for output value.'] + + def _on_bound(self): + super()._on_bound() + if self.examples: + for example in self.examples: + if not isinstance(example.value, Pair): + raise ValueError( + 'The value of example must be a `lf.structured.Pair` object. ' + f'Encountered: { example.value }.' + ) + + @property + def input_value(self) -> Any: + return schema_lib.mark_missing(self.message.result) + + def value_str(self, value: Any) -> str: + return schema_lib.value_repr('python').repr(value, compact=False) + + def missing_type_dependencies(self, value: Any) -> list[Type[Any]]: + value_specs = tuple( + [v.value_spec for v in schema_lib.Missing.find_missing(value).values()] + ) + return schema_lib.class_dependencies(value_specs, include_subclasses=True) + + def type_definitions_str(self, value: Any) -> str | None: + return schema_lib.class_definitions( + self.missing_type_dependencies(value), markdown=True + ) + + def _value_context(self): + classes = schema_lib.class_dependencies(self.input_value) + return {cls.__name__: cls for cls in classes} + + def transform_output(self, lm_output: lf.Message) -> lf.Message: + try: + result = schema_lib.value_repr('python').parse( + lm_output.text, additional_context=self._value_context() + ) + except Exception as e: # pylint: disable=broad-exception-caught + if self.default == lf.message_transform.RAISE_IF_HAS_ERROR: + raise MappingError( + 'Cannot parse message text into structured output. ' + f'Error={e}. Text={lm_output.text!r}.' + ) from e + result = self.default + lm_output.result = result + return lm_output diff --git a/langfun/core/structured/mapping_test.py b/langfun/core/structured/mapping_test.py index b0d9227f..8fd8229a 100644 --- a/langfun/core/structured/mapping_test.py +++ b/langfun/core/structured/mapping_test.py @@ -143,5 +143,33 @@ def test_serialization(self): ) +class Activity(pg.Object): + description: str + + +class Itinerary(pg.Object): + day: pg.typing.Int[1, None] + type: pg.typing.Enum['daytime', 'nighttime'] + activities: list[Activity] + hotel: pg.typing.Str['.*Hotel'] | None + + +class TripPlan(pg.Object): + place: str + itineraries: list[Itinerary] + + +class PairTest(unittest.TestCase): + + def test_partial(self): + p = mapping.Pair( + TripPlan.partial(place='San Francisco'), + TripPlan.partial(itineraries=[Itinerary.partial(day=1)]), + ) + self.assertEqual(p.left.itineraries, schema_lib.MISSING) + self.assertEqual(p.right.place, schema_lib.MISSING) + self.assertEqual(p.right.itineraries[0].activities, schema_lib.MISSING) + + if __name__ == '__main__': unittest.main() diff --git a/langfun/core/structured/nl2structure.py b/langfun/core/structured/parsing.py similarity index 55% rename from langfun/core/structured/nl2structure.py rename to langfun/core/structured/parsing.py index 0d1b2c54..78a58a2a 100644 --- a/langfun/core/structured/nl2structure.py +++ b/langfun/core/structured/parsing.py @@ -13,7 +13,6 @@ # limitations under the License. """Natural language text to structured value.""" -import abc import inspect from typing import Annotated, Any, Callable, Literal, Type, Union @@ -23,130 +22,8 @@ import pyglove as pg -class NaturalLanguageToStructure(mapping.Mapping): - """LangFunc for converting natural language text to structured value. - - {{ preamble }} - - {% if examples -%} - {% for example in examples -%} - {%- if example.nl_context -%} - {{ nl_context_title}}: - {{ example.nl_context | indent(2, True)}} - - {% endif -%} - {%- if example.nl_text -%} - {{ nl_text_title }}: - {{ example.nl_text | indent(2, True) }} - - {% endif -%} - {{ schema_title }}: - {{ example.schema_str(protocol) | indent(2, True) }} - - {{ value_title }}: - {{ example.value_str(protocol) | indent(2, True) }} - - {% endfor %} - {% endif -%} - {% if nl_context -%} - {{ nl_context_title }}: - {{ nl_context | indent(2, True)}} - - {% endif -%} - {% if nl_text -%} - {{ nl_text_title }}: - {{ nl_text | indent(2, True) }} - - {% endif -%} - {{ schema_title }}: - {{ schema.schema_str(protocol) | indent(2, True) }} - - {{ value_title }}: - """ - - schema: pg.typing.Annotated[ - # Automatic conversion from annotation to schema. - schema_lib.schema_spec(), - 'A `lf.structured.Schema` that constrains the structured value.', - ] - - default: Annotated[ - Any, - ( - 'The default value to use if parsing failed. ' - 'If unspecified, error will be raisen.' - ), - ] = lf.message_transform.RAISE_IF_HAS_ERROR - - preamble: Annotated[ - lf.LangFunc, - 'Preamble used for natural language-to-structure mapping.', - ] - - nl_context_title: Annotated[ - str, - 'The section title for nl_context.' - ] = 'USER_REQUEST' - - nl_text_title: Annotated[ - str, - 'The section title for nl_text.' - ] = 'LM_RESPONSE' - - schema_title: Annotated[str, 'The section title for schema.'] - - value_title: Annotated[str, 'The section title for schema.'] - - protocol: Annotated[ - schema_lib.SchemaProtocol, - 'The protocol for representing the schema and value.', - ] - - @property - @abc.abstractmethod - def nl_context(self) -> str | None: - """Returns the natural language context for obtaining the response. - - Returns: - The natural language context (prompt) for obtaining the response (either - in natural language or directly to structured protocol). If None, - `nl_text` - must be provided. - """ - - @property - @abc.abstractmethod - def nl_text(self) -> str | None: - """Returns the natural language text to map. - - Returns: - The natural language text (in LM response) to map to object. If None, - the LM directly outputs structured protocol instead of natural language. - If None, `nl_context` must be provided. - """ - - def transform_output(self, lm_output: lf.Message) -> lf.Message: - try: - lm_output.result = self.schema.parse( - lm_output.text, protocol=self.protocol - ) - except Exception as e: # pylint: disable=broad-exception-caught - if self.default == lf.message_transform.RAISE_IF_HAS_ERROR: - raise mapping.MappingError( - 'Cannot parse message text into structured output. ' - f'Error={e}. Text={lm_output.text!r}.' - ) from e - lm_output.result = self.default - return lm_output - - -# -# Parse object. -# - - @lf.use_init_args(['schema', 'default', 'examples']) -class ParseStructure(NaturalLanguageToStructure): +class ParseStructure(mapping.NaturalLanguageToStructure): """Parse an object out from a natural language text.""" # Customize the source of the text to be mapped and its context. @@ -391,152 +268,3 @@ def _default_parsing_examples() -> list[mapping.MappingExample]: lf.MessageTransform.as_structured = as_structured - - -# -# QueryStructure -# - - -@lf.use_init_args(['schema', 'default', 'examples']) -class QueryStructure(NaturalLanguageToStructure): - """Query an object out from a natural language text.""" - - @property - def nl_context(self) -> str: - """Returns the user request.""" - return self.message.text - - @property - def nl_text(self) -> None: - """Returns the LM response.""" - return None - - -class QueryStructureJson(QueryStructure): - """Query a structured value using JSON as the protocol.""" - - preamble = """ - Please respond to {{ nl_context_title }} with {{ value_title}} according to {{ schema_title }}: - - INSTRUCTIONS: - 1. If the schema has `_type`, carry it over to the JSON output. - 2. If a field from the schema cannot be extracted from the response, use null as the JSON value. - """ - - protocol = 'json' - schema_title = 'SCHEMA' - value_title = 'JSON' - - -class QueryStructurePython(QueryStructure): - """Query a structured value using Python as the protocol.""" - - preamble = """ - Please respond to {{ nl_context_title }} with {{ value_title }} according to {{ schema_title }}. - """ - protocol = 'python' - schema_title = 'RESULT_TYPE' - value_title = 'RESULT_OBJECT' - - -def _query_structure_cls( - protocol: schema_lib.SchemaProtocol, -) -> Type[QueryStructure]: - if protocol == 'json': - return QueryStructureJson - elif protocol == 'python': - return QueryStructurePython - else: - raise ValueError(f'Unknown protocol: {protocol!r}.') - - -def _default_query_examples() -> list[mapping.MappingExample]: - return [ - mapping.MappingExample( - nl_context='Brief introduction of the U.S.A.', - schema=_Country, - value=_Country( - name='The United States of America', - continents=['North America'], - num_states=50, - neighbor_countries=[ - 'Canada', - 'Mexico', - 'Bahamas', - 'Cuba', - 'Russia', - ], - population=333000000, - capital='Washington, D.C', - president=None, - ), - ) - ] - - -def query( - prompt: Union[lf.Message, str], - schema: Union[ - schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any] - ], - default: Any = lf.message_transform.RAISE_IF_HAS_ERROR, - *, - examples: list[mapping.MappingExample] | None = None, - protocol: schema_lib.SchemaProtocol = 'python', - **kwargs, -) -> Any: - """Parse a natural langugage message based on schema. - - Examples: - - ``` - class FlightDuration: - hours: int - minutes: int - - class Flight(pg.Object): - airline: str - flight_number: str - departure_airport_code: str - arrival_airport_code: str - departure_time: str - arrival_time: str - duration: FlightDuration - stops: int - price: float - - prompt = ''' - Information about flight UA2631. - ''' - - r = lf.query(prompt, Flight) - assert isinstance(r, Flight) - assert r.airline == 'United Airlines' - assert r.departure_airport_code == 'SFO' - assert r.duration.hour = 7 - ``` - - Args: - prompt: A `lf.Message` object or a string as the natural language prompt. - schema: A `lf.transforms.ParsingSchema` object or equivalent annotations. - default: The default value if parsing failed. If not specified, error will - be raised. - examples: An optional list of fewshot examples for helping parsing. If None, - the default one-shot example will be added. - protocol: The protocol for schema/value representation. Applicable values - are 'json' and 'python'. By default `python` will be used. - **kwargs: Keyword arguments passed to the - `lf.structured.NaturalLanguageToStructureed` transform, e.g. `lm` for - specifying the language model for structured parsing. - - Returns: - The result based on the schema. - """ - if examples is None: - examples = _default_query_examples() - t = _query_structure_cls(protocol)( - schema, default=default, examples=examples, **kwargs - ) - message = lf.AIMessage.from_value(prompt) - return t.transform(message=message).result diff --git a/langfun/core/structured/nl2structure_test.py b/langfun/core/structured/parsing_test.py similarity index 57% rename from langfun/core/structured/nl2structure_test.py rename to langfun/core/structured/parsing_test.py index 8857fa10..1ea440ce 100644 --- a/langfun/core/structured/nl2structure_test.py +++ b/langfun/core/structured/parsing_test.py @@ -19,7 +19,7 @@ import langfun.core as lf from langfun.core.llms import fake from langfun.core.structured import mapping -from langfun.core.structured import nl2structure +from langfun.core.structured import parsing import pyglove as pg @@ -37,7 +37,7 @@ class Itinerary(pg.Object): class ParseStructurePythonTest(unittest.TestCase): def test_render_no_examples(self): - l = nl2structure.ParseStructurePython(int) + l = parsing.ParseStructurePython(int) m = lf.AIMessage('Bla bla bla 12 / 6 + 2 = 4.', result='12 / 6 + 2 = 4') m.source = lf.UserMessage('Compute 12 / 6 + 2.', tags=['lm-input']) @@ -61,7 +61,7 @@ def test_render_no_examples(self): ) def test_render_no_context(self): - l = nl2structure.ParseStructurePython(int) + l = parsing.ParseStructurePython(int) m = lf.AIMessage('Bla bla bla 12 / 6 + 2 = 4.', result='12 / 6 + 2 = 4') self.assertEqual( @@ -81,7 +81,7 @@ def test_render_no_context(self): ) def test_render(self): - l = nl2structure.ParseStructurePython( + l = parsing.ParseStructurePython( int, examples=[ mapping.MappingExample( @@ -201,7 +201,7 @@ def test_transform(self): ), override_attrs=True, ): - l = lf.LangFunc(lm_input) >> nl2structure.ParseStructurePython( + l = lf.LangFunc(lm_input) >> parsing.ParseStructurePython( [Itinerary], examples=[ mapping.MappingExample( @@ -248,13 +248,11 @@ def test_bad_transform(self): def test_parse(self): lm = fake.StaticSequence(['1']) + self.assertEqual(parsing.parse('the answer is 1', int, lm=lm), 1) self.assertEqual( - nl2structure.parse('the answer is 1', int, lm=lm), - 1 - ) - self.assertEqual( - nl2structure.parse( - 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm), + parsing.parse( + 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm + ), 1, ) @@ -262,7 +260,7 @@ def test_parse(self): class ParseStructureJsonTest(unittest.TestCase): def test_render_no_examples(self): - l = nl2structure.ParseStructureJson(int) + l = parsing.ParseStructureJson(int) m = lf.AIMessage('Bla bla bla 12 / 6 + 2 = 4.', result='12 / 6 + 2 = 4') m.source = lf.UserMessage('Compute 12 / 6 + 2.', tags=['lm-input']) @@ -289,7 +287,7 @@ def test_render_no_examples(self): ) def test_render_no_context(self): - l = nl2structure.ParseStructureJson(int) + l = parsing.ParseStructureJson(int) m = lf.AIMessage('Bla bla bla 12 / 6 + 2 = 4.', result='12 / 6 + 2 = 4') self.assertEqual( @@ -312,7 +310,7 @@ def test_render_no_context(self): ) def test_render(self): - l = nl2structure.ParseStructureJson( + l = parsing.ParseStructureJson( int, examples=[ mapping.MappingExample( @@ -468,7 +466,7 @@ def test_transform(self): ), override_attrs=True, ): - l = lf.LangFunc(lm_input) >> nl2structure.ParseStructureJson( + l = lf.LangFunc(lm_input) >> parsing.ParseStructureJson( [Itinerary], examples=[ mapping.MappingExample( @@ -513,370 +511,30 @@ def test_bad_transform(self): ): lf.LangFunc('Compute 1 + 2').as_structured(int)() - def test_parse(self): - lm = fake.StaticSequence(['{"result": 1}']) - self.assertEqual( - nl2structure.parse('the answer is 1', int, lm=lm, protocol='json'), - 1 - ) - self.assertEqual( - nl2structure.parse( - 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm, - protocol='json', - ), - 1, - ) - - -class QueryStructurePythonTest(unittest.TestCase): - - def test_render_no_examples(self): - l = nl2structure.QueryStructurePython(int) - m = lf.AIMessage('Compute 12 / 6 + 2.') - - self.assertEqual( - l.render(message=m).text, - inspect.cleandoc(""" - Please respond to USER_REQUEST with RESULT_OBJECT according to RESULT_TYPE. - - USER_REQUEST: - Compute 12 / 6 + 2. - - RESULT_TYPE: - int - - RESULT_OBJECT: - """), - ) - - def test_render(self): - l = nl2structure.QueryStructurePython( - int, - examples=[ - mapping.MappingExample('What is the answer of 1 plus 1?', None, 2), - mapping.MappingExample( - 'Compute the value of 3 + (2 * 6).', None, 15 - ), - ], - ) - self.assertEqual( - l.render(message=lf.AIMessage('Compute 12 / 6 + 2.')).text, - inspect.cleandoc(""" - Please respond to USER_REQUEST with RESULT_OBJECT according to RESULT_TYPE. - - USER_REQUEST: - What is the answer of 1 plus 1? - - RESULT_TYPE: - int - - RESULT_OBJECT: - ```python - 2 - ``` - - USER_REQUEST: - Compute the value of 3 + (2 * 6). - - RESULT_TYPE: - int - - RESULT_OBJECT: - ```python - 15 - ``` - - - USER_REQUEST: - Compute 12 / 6 + 2. - - RESULT_TYPE: - int - - RESULT_OBJECT: - """), - ) - - def test_transform(self): - lm_input = lf.UserMessage('3-day itineraries to San Francisco') - parse_structured_response = inspect.cleandoc( - """ - ```python - [ - Itinerary( - day=1, - type='daytime', - activities=[ - Activity(description='Arrive in San Francisco and check into your hotel.'), - Activity(description='Take a walk around Fisherman\\'s Wharf and have dinner at one of the many seafood restaurants.'), - Activity(description='Visit Pier 39 and see the sea lions.'), - ], - hotel=None), - Itinerary( - day=2, - type='daytime', - activities=[ - Activity(description='Take a ferry to Alcatraz Island and tour the infamous prison.'), - Activity(description='Take a walk across the Golden Gate Bridge.'), - Activity(description='Visit the Japanese Tea Garden in Golden Gate Park.'), - ], - hotel=None), - Itinerary( - day=3, - type='daytime', - activities=[ - Activity(description='Visit the de Young Museum and see the collection of American art.'), - Activity(description='Visit the San Francisco Museum of Modern Art.'), - Activity(description='Take a cable car ride.'), - ], - hotel=None), - ] - ``` - """) with lf.context( - lm=fake.StaticSequence( - [parse_structured_response], - ), + lm=fake.StaticSequence(['three', '`3`']), override_attrs=True, ): - l = nl2structure.QueryStructurePython( - [Itinerary], - examples=[ - mapping.MappingExample( - nl_context=inspect.cleandoc(""" - Find the alternatives of expressing \"feeling great\". - """), - schema={'expression': str, 'words': list[str]}, - value={ - 'expression': 'feeling great', - 'words': [ - 'Ecstatic', - 'Delighted', - 'Wonderful', - 'Enjoyable', - 'Fantastic', - ], - }, - ) - ], + # Test default. + self.assertIsNone( + lf.LangFunc('Compute 1 + 2').as_structured(int, default=None)().result ) - r = l(message=lm_input) - self.assertEqual(len(r.result), 3) - self.assertIsInstance(r.result[0], Itinerary) - self.assertEqual(len(r.result[0].activities), 3) - self.assertIsNone(r.result[0].hotel) - - def test_bad_transform(self): - with lf.context( - lm=fake.StaticSequence(['a2']), - override_attrs=True, - ): - with self.assertRaisesRegex( - mapping.MappingError, - 'Cannot parse message text into structured output', - ): - nl2structure.query('Compute 1 + 2', int) - - def test_query(self): - lm = fake.StaticSequence(['1']) - self.assertEqual(nl2structure.query('what is 1 + 0', int, lm=lm), 1) - - -class QueryStructureJsonTest(unittest.TestCase): - - def test_render_no_examples(self): - l = nl2structure.QueryStructureJson(int) - m = lf.AIMessage('Compute 12 / 6 + 2.') + def test_parse(self): + lm = fake.StaticSequence(['{"result": 1}']) self.assertEqual( - l.render(message=m).text, - inspect.cleandoc(""" - Please respond to USER_REQUEST with JSON according to SCHEMA: - - INSTRUCTIONS: - 1. If the schema has `_type`, carry it over to the JSON output. - 2. If a field from the schema cannot be extracted from the response, use null as the JSON value. - - USER_REQUEST: - Compute 12 / 6 + 2. - - SCHEMA: - {"result": int} - - JSON: - """), - ) - - def test_render(self): - l = nl2structure.QueryStructureJson( - int, - examples=[ - mapping.MappingExample('What is the answer of 1 plus 1?', None, 2), - mapping.MappingExample( - 'Compute the value of 3 + (2 * 6).', None, 15 - ), - ], + parsing.parse('the answer is 1', int, lm=lm, protocol='json'), 1 ) self.assertEqual( - l.render(message=lf.AIMessage('Compute 12 / 6 + 2.')).text, - inspect.cleandoc(""" - Please respond to USER_REQUEST with JSON according to SCHEMA: - - INSTRUCTIONS: - 1. If the schema has `_type`, carry it over to the JSON output. - 2. If a field from the schema cannot be extracted from the response, use null as the JSON value. - - USER_REQUEST: - What is the answer of 1 plus 1? - - SCHEMA: - {"result": int} - - JSON: - {"result": 2} - - USER_REQUEST: - Compute the value of 3 + (2 * 6). - - SCHEMA: - {"result": int} - - JSON: - {"result": 15} - - - USER_REQUEST: - Compute 12 / 6 + 2. - - SCHEMA: - {"result": int} - - JSON: - """), - ) - - def test_transform(self): - lm_input = lf.UserMessage('3-day itineraries to San Francisco') - parse_structured_response = ( - lf.LangFunc( - """ - {"result": [ - { - "_type": {{itinerary_type}}, - "day": 1, - "type": "daytime", - "activities": [ - { - "_type": {{activity_type}}, - "description": "Arrive in San Francisco and check into your hotel." - }, - { - "_type": {{activity_type}}, - "description": "Take a walk around Fisherman's Wharf and have dinner at one of the many seafood restaurants." - }, - { - "_type": {{activity_type}}, - "description": "Visit Pier 39 and see the sea lions." - } - ], - "hotel": null - }, - { - "_type": {{itinerary_type}}, - "day": 2, - "type": "daytime", - "activities": [ - { - "_type": {{activity_type}}, - "description": "Take a ferry to Alcatraz Island and tour the infamous prison." - }, - { - "_type": {{activity_type}}, - "description": "Take a walk across the Golden Gate Bridge." - }, - { - "_type": {{activity_type}}, - "description": "Visit the Japanese Tea Garden in Golden Gate Park." - } - ], - "hotel": null - }, - { - "_type": {{itinerary_type}}, - "day": 3, - "type": "daytime", - "activities": [ - { - "_type": {{activity_type}}, - "description": "Visit the de Young Museum and see the collection of American art." - }, - { - "_type": {{activity_type}}, - "description": "Visit the San Francisco Museum of Modern Art." - }, - { - "_type": {{activity_type}}, - "description": "Take a cable car ride." - } - ], - "hotel": null - } - ]} - """, - itinerary_type=f'"{Itinerary.__type_name__}"', - activity_type=f'"{Activity.__type_name__}"', - ) - .render() - .text - ) - with lf.context( - lm=fake.StaticSequence( - [parse_structured_response], + parsing.parse( + 'the answer is 1', + int, + user_prompt='what is 0 + 1?', + lm=lm, + protocol='json', ), - override_attrs=True, - ): - l = nl2structure.QueryStructureJson( - [Itinerary], - examples=[ - mapping.MappingExample( - nl_context=inspect.cleandoc(""" - Find the alternatives of expressing \"feeling great\". - """), - schema={'expression': str, 'words': list[str]}, - value={ - 'expression': 'feeling great', - 'words': [ - 'Ecstatic', - 'Delighted', - 'Wonderful', - 'Enjoyable', - 'Fantastic', - ], - }, - ) - ], - ) - r = l(message=lm_input) - self.assertEqual(len(r.result), 3) - self.assertIsInstance(r.result[0], Itinerary) - self.assertEqual(len(r.result[0].activities), 3) - self.assertIsNone(r.result[0].hotel) - - def test_bad_transform(self): - with lf.context( - lm=fake.StaticSequence(['3']), - override_attrs=True, - ): - with self.assertRaisesRegex( - mapping.MappingError, - 'Cannot parse message text into structured output', - ): - nl2structure.query('Compute 1 + 2', int, protocol='json') - - def test_query(self): - lm = fake.StaticSequence(['{"result": 1}']) - self.assertEqual( - nl2structure.query('what is 1 + 0', int, lm=lm, protocol='json'), 1) + 1, + ) if __name__ == '__main__': diff --git a/langfun/core/structured/query.py b/langfun/core/structured/query.py new file mode 100644 index 00000000..84c363ae --- /dev/null +++ b/langfun/core/structured/query.py @@ -0,0 +1,183 @@ +# Copyright 2023 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Natural language text to structured value.""" + +from typing import Any, Literal, Type, Union + +import langfun.core as lf +from langfun.core.structured import mapping +from langfun.core.structured import schema as schema_lib +import pyglove as pg + + +@lf.use_init_args(['schema', 'default', 'examples']) +class QueryStructure(mapping.NaturalLanguageToStructure): + """Query an object out from a natural language text.""" + + @property + def nl_context(self) -> str: + """Returns the user request.""" + return self.message.text + + @property + def nl_text(self) -> None: + """Returns the LM response.""" + return None + + +class QueryStructureJson(QueryStructure): + """Query a structured value using JSON as the protocol.""" + + preamble = """ + Please respond to {{ nl_context_title }} with {{ value_title}} according to {{ schema_title }}: + + INSTRUCTIONS: + 1. If the schema has `_type`, carry it over to the JSON output. + 2. If a field from the schema cannot be extracted from the response, use null as the JSON value. + """ + + protocol = 'json' + schema_title = 'SCHEMA' + value_title = 'JSON' + + +class QueryStructurePython(QueryStructure): + """Query a structured value using Python as the protocol.""" + + preamble = """ + Please respond to {{ nl_context_title }} with {{ value_title }} according to {{ schema_title }}. + """ + protocol = 'python' + schema_title = 'RESULT_TYPE' + value_title = 'RESULT_OBJECT' + + +def _query_structure_cls( + protocol: schema_lib.SchemaProtocol, +) -> Type[QueryStructure]: + if protocol == 'json': + return QueryStructureJson + elif protocol == 'python': + return QueryStructurePython + else: + raise ValueError(f'Unknown protocol: {protocol!r}.') + + +class _Country(pg.Object): + """A example dataclass for structured parsing.""" + name: str + continents: list[Literal[ + 'Africa', + 'Asia', + 'Europe', + 'Oceania', + 'North America', + 'South America' + ]] + num_states: int + neighbor_countries: list[str] + population: int + capital: str | None + president: str | None + + +def _default_query_examples() -> list[mapping.MappingExample]: + return [ + mapping.MappingExample( + nl_context='Brief introduction of the U.S.A.', + schema=_Country, + value=_Country( + name='The United States of America', + continents=['North America'], + num_states=50, + neighbor_countries=[ + 'Canada', + 'Mexico', + 'Bahamas', + 'Cuba', + 'Russia', + ], + population=333000000, + capital='Washington, D.C', + president=None, + ), + ) + ] + + +def query( + prompt: Union[lf.Message, str], + schema: Union[ + schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any] + ], + default: Any = lf.message_transform.RAISE_IF_HAS_ERROR, + *, + examples: list[mapping.MappingExample] | None = None, + protocol: schema_lib.SchemaProtocol = 'python', + **kwargs, +) -> Any: + """Parse a natural langugage message based on schema. + + Examples: + + ``` + class FlightDuration: + hours: int + minutes: int + + class Flight(pg.Object): + airline: str + flight_number: str + departure_airport_code: str + arrival_airport_code: str + departure_time: str + arrival_time: str + duration: FlightDuration + stops: int + price: float + + prompt = ''' + Information about flight UA2631. + ''' + + r = lf.query(prompt, Flight) + assert isinstance(r, Flight) + assert r.airline == 'United Airlines' + assert r.departure_airport_code == 'SFO' + assert r.duration.hour = 7 + ``` + + Args: + prompt: A `lf.Message` object or a string as the natural language prompt. + schema: A `lf.transforms.ParsingSchema` object or equivalent annotations. + default: The default value if parsing failed. If not specified, error will + be raised. + examples: An optional list of fewshot examples for helping parsing. If None, + the default one-shot example will be added. + protocol: The protocol for schema/value representation. Applicable values + are 'json' and 'python'. By default `python` will be used. + **kwargs: Keyword arguments passed to the + `lf.structured.NaturalLanguageToStructureed` transform, e.g. `lm` for + specifying the language model for structured parsing. + + Returns: + The result based on the schema. + """ + if examples is None: + examples = _default_query_examples() + t = _query_structure_cls(protocol)( + schema, default=default, examples=examples, **kwargs + ) + message = lf.AIMessage.from_value(prompt) + return t.transform(message=message).result diff --git a/langfun/core/structured/query_test.py b/langfun/core/structured/query_test.py new file mode 100644 index 00000000..2a4c4f0d --- /dev/null +++ b/langfun/core/structured/query_test.py @@ -0,0 +1,390 @@ +# Copyright 2023 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for structured query_lib.""" + +import inspect +import unittest + +import langfun.core as lf +from langfun.core.llms import fake +from langfun.core.structured import mapping +from langfun.core.structured import query as query_lib +import pyglove as pg + + +class Activity(pg.Object): + description: str + + +class Itinerary(pg.Object): + day: pg.typing.Int[1, None] + type: pg.typing.Enum['daytime', 'nighttime'] + activities: list[Activity] + hotel: pg.typing.Str['.*Hotel'] | None + + +class QueryStructurePythonTest(unittest.TestCase): + + def test_render_no_examples(self): + l = query_lib.QueryStructurePython(int) + m = lf.AIMessage('Compute 12 / 6 + 2.') + + self.assertEqual( + l.render(message=m).text, + inspect.cleandoc(""" + Please respond to USER_REQUEST with RESULT_OBJECT according to RESULT_TYPE. + + USER_REQUEST: + Compute 12 / 6 + 2. + + RESULT_TYPE: + int + + RESULT_OBJECT: + """), + ) + + def test_render(self): + l = query_lib.QueryStructurePython( + int, + examples=[ + mapping.MappingExample('What is the answer of 1 plus 1?', None, 2), + mapping.MappingExample( + 'Compute the value of 3 + (2 * 6).', None, 15 + ), + ], + ) + self.assertEqual( + l.render(message=lf.AIMessage('Compute 12 / 6 + 2.')).text, + inspect.cleandoc(""" + Please respond to USER_REQUEST with RESULT_OBJECT according to RESULT_TYPE. + + USER_REQUEST: + What is the answer of 1 plus 1? + + RESULT_TYPE: + int + + RESULT_OBJECT: + ```python + 2 + ``` + + USER_REQUEST: + Compute the value of 3 + (2 * 6). + + RESULT_TYPE: + int + + RESULT_OBJECT: + ```python + 15 + ``` + + + USER_REQUEST: + Compute 12 / 6 + 2. + + RESULT_TYPE: + int + + RESULT_OBJECT: + """), + ) + + def test_transform(self): + lm_input = lf.UserMessage('3-day itineraries to San Francisco') + parse_structured_response = inspect.cleandoc( + """ + ```python + [ + Itinerary( + day=1, + type='daytime', + activities=[ + Activity(description='Arrive in San Francisco and check into your hotel.'), + Activity(description='Take a walk around Fisherman\\'s Wharf and have dinner at one of the many seafood restaurants.'), + Activity(description='Visit Pier 39 and see the sea lions.'), + ], + hotel=None), + Itinerary( + day=2, + type='daytime', + activities=[ + Activity(description='Take a ferry to Alcatraz Island and tour the infamous prison.'), + Activity(description='Take a walk across the Golden Gate Bridge.'), + Activity(description='Visit the Japanese Tea Garden in Golden Gate Park.'), + ], + hotel=None), + Itinerary( + day=3, + type='daytime', + activities=[ + Activity(description='Visit the de Young Museum and see the collection of American art.'), + Activity(description='Visit the San Francisco Museum of Modern Art.'), + Activity(description='Take a cable car ride.'), + ], + hotel=None), + ] + ``` + """) + with lf.context( + lm=fake.StaticSequence( + [parse_structured_response], + ), + override_attrs=True, + ): + l = query_lib.QueryStructurePython( + [Itinerary], + examples=[ + mapping.MappingExample( + nl_context=inspect.cleandoc(""" + Find the alternatives of expressing \"feeling great\". + """), + schema={'expression': str, 'words': list[str]}, + value={ + 'expression': 'feeling great', + 'words': [ + 'Ecstatic', + 'Delighted', + 'Wonderful', + 'Enjoyable', + 'Fantastic', + ], + }, + ) + ], + ) + r = l(message=lm_input) + self.assertEqual(len(r.result), 3) + self.assertIsInstance(r.result[0], Itinerary) + self.assertEqual(len(r.result[0].activities), 3) + self.assertIsNone(r.result[0].hotel) + + def test_bad_transform(self): + with lf.context( + lm=fake.StaticSequence(['a2']), + override_attrs=True, + ): + with self.assertRaisesRegex( + mapping.MappingError, + 'Cannot parse message text into structured output', + ): + query_lib.query('Compute 1 + 2', int) + + def test_query(self): + lm = fake.StaticSequence(['1']) + self.assertEqual(query_lib.query('what is 1 + 0', int, lm=lm), 1) + + +class QueryStructureJsonTest(unittest.TestCase): + + def test_render_no_examples(self): + l = query_lib.QueryStructureJson(int) + m = lf.AIMessage('Compute 12 / 6 + 2.') + + self.assertEqual( + l.render(message=m).text, + inspect.cleandoc(""" + Please respond to USER_REQUEST with JSON according to SCHEMA: + + INSTRUCTIONS: + 1. If the schema has `_type`, carry it over to the JSON output. + 2. If a field from the schema cannot be extracted from the response, use null as the JSON value. + + USER_REQUEST: + Compute 12 / 6 + 2. + + SCHEMA: + {"result": int} + + JSON: + """), + ) + + def test_render(self): + l = query_lib.QueryStructureJson( + int, + examples=[ + mapping.MappingExample('What is the answer of 1 plus 1?', None, 2), + mapping.MappingExample( + 'Compute the value of 3 + (2 * 6).', None, 15 + ), + ], + ) + self.assertEqual( + l.render(message=lf.AIMessage('Compute 12 / 6 + 2.')).text, + inspect.cleandoc(""" + Please respond to USER_REQUEST with JSON according to SCHEMA: + + INSTRUCTIONS: + 1. If the schema has `_type`, carry it over to the JSON output. + 2. If a field from the schema cannot be extracted from the response, use null as the JSON value. + + USER_REQUEST: + What is the answer of 1 plus 1? + + SCHEMA: + {"result": int} + + JSON: + {"result": 2} + + USER_REQUEST: + Compute the value of 3 + (2 * 6). + + SCHEMA: + {"result": int} + + JSON: + {"result": 15} + + + USER_REQUEST: + Compute 12 / 6 + 2. + + SCHEMA: + {"result": int} + + JSON: + """), + ) + + def test_transform(self): + lm_input = lf.UserMessage('3-day itineraries to San Francisco') + parse_structured_response = ( + lf.LangFunc( + """ + {"result": [ + { + "_type": {{itinerary_type}}, + "day": 1, + "type": "daytime", + "activities": [ + { + "_type": {{activity_type}}, + "description": "Arrive in San Francisco and check into your hotel." + }, + { + "_type": {{activity_type}}, + "description": "Take a walk around Fisherman's Wharf and have dinner at one of the many seafood restaurants." + }, + { + "_type": {{activity_type}}, + "description": "Visit Pier 39 and see the sea lions." + } + ], + "hotel": null + }, + { + "_type": {{itinerary_type}}, + "day": 2, + "type": "daytime", + "activities": [ + { + "_type": {{activity_type}}, + "description": "Take a ferry to Alcatraz Island and tour the infamous prison." + }, + { + "_type": {{activity_type}}, + "description": "Take a walk across the Golden Gate Bridge." + }, + { + "_type": {{activity_type}}, + "description": "Visit the Japanese Tea Garden in Golden Gate Park." + } + ], + "hotel": null + }, + { + "_type": {{itinerary_type}}, + "day": 3, + "type": "daytime", + "activities": [ + { + "_type": {{activity_type}}, + "description": "Visit the de Young Museum and see the collection of American art." + }, + { + "_type": {{activity_type}}, + "description": "Visit the San Francisco Museum of Modern Art." + }, + { + "_type": {{activity_type}}, + "description": "Take a cable car ride." + } + ], + "hotel": null + } + ]} + """, + itinerary_type=f'"{Itinerary.__type_name__}"', + activity_type=f'"{Activity.__type_name__}"', + ) + .render() + .text + ) + with lf.context( + lm=fake.StaticSequence( + [parse_structured_response], + ), + override_attrs=True, + ): + l = query_lib.QueryStructureJson( + [Itinerary], + examples=[ + mapping.MappingExample( + nl_context=inspect.cleandoc(""" + Find the alternatives of expressing \"feeling great\". + """), + schema={'expression': str, 'words': list[str]}, + value={ + 'expression': 'feeling great', + 'words': [ + 'Ecstatic', + 'Delighted', + 'Wonderful', + 'Enjoyable', + 'Fantastic', + ], + }, + ) + ], + ) + r = l(message=lm_input) + self.assertEqual(len(r.result), 3) + self.assertIsInstance(r.result[0], Itinerary) + self.assertEqual(len(r.result[0].activities), 3) + self.assertIsNone(r.result[0].hotel) + + def test_bad_transform(self): + with lf.context( + lm=fake.StaticSequence(['3']), + override_attrs=True, + ): + with self.assertRaisesRegex( + mapping.MappingError, + 'Cannot parse message text into structured output', + ): + query_lib.query('Compute 1 + 2', int, protocol='json') + + def test_query(self): + lm = fake.StaticSequence(['{"result": 1}']) + self.assertEqual( + query_lib.query('what is 1 + 0', int, lm=lm, protocol='json'), 1 + ) + + +if __name__ == '__main__': + unittest.main()