Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure lf.structured files. #17

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions langfun/core/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,28 @@
from langfun.core.structured.mapping import MappingExample
from langfun.core.structured.mapping import MappingError

from langfun.core.structured.nl2structure import NaturalLanguageToStructure
from langfun.core.structured.nl2structure import ParseStructure
from langfun.core.structured.nl2structure import ParseStructureJson
from langfun.core.structured.nl2structure import ParseStructurePython
from langfun.core.structured.nl2structure import QueryStructure
from langfun.core.structured.nl2structure import QueryStructureJson
from langfun.core.structured.nl2structure import QueryStructurePython
from langfun.core.structured.nl2structure import parse
from langfun.core.structured.nl2structure import query
# Mappings between different forms of content.
from langfun.core.structured.mapping import NaturalLanguageToStructure
from langfun.core.structured.mapping import StructureToNaturalLanguage
from langfun.core.structured.mapping import StructureToStructure

from langfun.core.structured.structure2nl import StructureToNaturalLanguage
from langfun.core.structured.structure2nl import DescribeStructure
from langfun.core.structured.structure2nl import describe
from langfun.core.structured.parsing import ParseStructure
from langfun.core.structured.parsing import ParseStructureJson
from langfun.core.structured.parsing import ParseStructurePython
from langfun.core.structured.parsing import parse

from langfun.core.structured.structure2structure import StructureToStructure
from langfun.core.structured.structure2structure import CompleteStructure
from langfun.core.structured.structure2structure import complete
import langfun.core.structured.query as query_lib

from langfun.core.structured.query import QueryStructure
from langfun.core.structured.query import QueryStructureJson
from langfun.core.structured.query import QueryStructurePython
from langfun.core.structured.query import query

from langfun.core.structured.description import DescribeStructure
from langfun.core.structured.description import describe

from langfun.core.structured.completion import CompleteStructure
from langfun.core.structured.completion import complete


# pylint: enable=g-importing-member
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,126 +13,15 @@
# limitations under the License.
"""Structure-to-structure mappings."""

from typing import Annotated, Any, Literal, Type
from typing import Any, Literal

import langfun.core as lf
from langfun.core.structured import mapping
from langfun.core.structured import schema as schema_lib
import pyglove as pg


# Holds one structure-to-structure example as two values: `left` and `right`.
# Both fields accept any value and route it through `schema_lib.mark_missing`
# via the `transform=` hook of the value spec.
# NOTE(review): presumably `mark_missing` tags the unset fields of partial
# objects as missing -- confirm against `schema_lib`.
class Pair(pg.Object):
"""Value pair used for expressing structure-to-structure mapping."""

# Left-side value; transformed by `schema_lib.mark_missing` on assignment.
left: pg.typing.Annotated[
pg.typing.Any(transform=schema_lib.mark_missing), 'The left-side value.'
]
# Right-side value; transformed by `schema_lib.mark_missing` on assignment.
right: pg.typing.Annotated[
pg.typing.Any(transform=schema_lib.mark_missing), 'The right-side value.'
]


# Base mapping from one structured value to another.
#
# NOTE(review): the class docstring below contains Jinja syntax and appears to
# double as the prompt template -- treat its text (including blank lines and
# `-%}` / `{%-` whitespace control) as runtime behavior and do not reword it.
# As written, the template emits: the preamble, then for each example the
# input-value section (`example.value.left`), an optional type-definitions
# section, and the output-value section (`example.value.right`); finally the
# same sections for the actual input value, ending at the output-value title.
class StructureToStructure(mapping.Mapping):
"""Base class for structure-to-structure mapping.

{{ preamble }}

{% if examples -%}
{% for example in examples -%}
{{ input_value_title }}:
{{ value_str(example.value.left) | indent(2, True) }}

{%- if missing_type_dependencies(example.value) %}

{{ type_definitions_title }}:
{{ type_definitions_str(example.value) | indent(2, True) }}
{%- endif %}

{{ output_value_title }}:
{{ value_str(example.value.right) | indent(2, True) }}

{% endfor %}
{% endif -%}
{{ input_value_title }}:
{{ value_str(input_value) | indent(2, True) }}
{%- if missing_type_dependencies(input_value) %}

{{ type_definitions_title }}:
{{ type_definitions_str(input_value) | indent(2, True) }}
{%- endif %}

{{ output_value_title }}:
"""

# Fallback used by `transform_output` when parsing the LM output fails. The
# sentinel `lf.message_transform.RAISE_IF_HAS_ERROR` (the default) makes
# `transform_output` raise `mapping.MappingError` instead of falling back.
# NOTE(review): "raisen" in the description string is a typo for "raised";
# it is a string literal, so it is left unchanged in this documentation pass.
default: Annotated[
Any,
(
'The default value to use if mapping failed. '
'If unspecified, error will be raisen.'
),
] = lf.message_transform.RAISE_IF_HAS_ERROR

# Required (no default here); referenced as `{{ preamble }}` in the template.
preamble: Annotated[
lf.LangFunc,
'Preamble used for structure-to-structure mapping.',
]

# Heading printed before the type-definitions section of the template.
type_definitions_title: Annotated[
str, 'The section title for type definitions.'
] = 'CLASS_DEFINITIONS'

# Headings for the input and output sections; no defaults are given here.
input_value_title: Annotated[str, 'The section title for input value.']
output_value_title: Annotated[str, 'The section title for output value.']

# Runs the parent `_on_bound` hook, then validates the examples: every
# example's `value` must be a `Pair`, otherwise ValueError is raised.
def _on_bound(self):
super()._on_bound()
if self.examples:
for example in self.examples:
if not isinstance(example.value, Pair):
raise ValueError(
'The value of example must be a `lf.structured.Pair` object. '
f'Encountered: { example.value }.'
)

# The value to be mapped: `self.message.result` passed through
# `schema_lib.mark_missing`.
@property
def input_value(self) -> Any:
return schema_lib.mark_missing(self.message.result)

# Renders `value` with the 'python' value representation, non-compact.
def value_str(self, value: Any) -> str:
return schema_lib.value_repr('python').repr(value, compact=False)

# Gathers the `value_spec` of every entry reported by
# `schema_lib.Missing.find_missing(value)` and returns the class dependencies
# of those specs, with subclasses included.
def missing_type_dependencies(self, value: Any) -> list[Type[Any]]:
value_specs = tuple(
[v.value_spec for v in schema_lib.Missing.find_missing(value).values()]
)
return schema_lib.class_dependencies(value_specs, include_subclasses=True)

# Class definitions (requested with `markdown=True`) for the classes returned
# by `missing_type_dependencies(value)`.
def type_definitions_str(self, value: Any) -> str | None:
return schema_lib.class_definitions(
self.missing_type_dependencies(value), markdown=True
)

# Maps class name -> class for the class dependencies of the input value;
# supplied as `additional_context` when parsing the LM output below.
def _value_context(self):
classes = schema_lib.class_dependencies(self.input_value)
return {cls.__name__: cls for cls in classes}

# Parses `lm_output.text` with the 'python' value representation and stores
# the parsed value in `lm_output.result`, returning the same message object.
# On any exception during parsing: raises `mapping.MappingError` (chained to
# the original error) when `default` is the RAISE_IF_HAS_ERROR sentinel;
# otherwise `default` becomes the result.
def transform_output(self, lm_output: lf.Message) -> lf.Message:
try:
result = schema_lib.value_repr('python').parse(
lm_output.text, additional_context=self._value_context()
)
# Deliberately broad: any parse failure is either wrapped or replaced below.
except Exception as e: # pylint: disable=broad-exception-caught
if self.default == lf.message_transform.RAISE_IF_HAS_ERROR:
raise mapping.MappingError(
'Cannot parse message text into structured output. '
f'Error={e}. Text={lm_output.text!r}.'
) from e
# A caller-supplied default was given: use it instead of raising.
result = self.default
lm_output.result = result
return lm_output


class CompleteStructure(StructureToStructure):
class CompleteStructure(mapping.StructureToStructure):
"""Complete structure by filling the missing fields."""

preamble = lf.LangFunc("""
Expand Down Expand Up @@ -173,7 +62,7 @@ class _Country(pg.Object):
def _default_complete_examples() -> list[mapping.MappingExample]:
return [
mapping.MappingExample(
value=Pair(
value=mapping.Pair(
left=_Country.partial(name='United States of America'),
right=_Country(
name='United States of America',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for langfun.core.structured.structure2structure."""
"""Tests for langfun.core.structured.completion."""

import inspect
import unittest

import langfun.core as lf
from langfun.core.llms import fake
from langfun.core.structured import completion
from langfun.core.structured import mapping
from langfun.core.structured import schema as schema_lib
from langfun.core.structured import structure2structure
import pyglove as pg


Expand All @@ -39,22 +39,10 @@ class TripPlan(pg.Object):
itineraries: list[Itinerary]


# Tests for `structure2structure.Pair`.
class PairTest(unittest.TestCase):

# Builds a `Pair` from two partially-specified `TripPlan` objects and checks
# that the fields left unset read back as `schema_lib.MISSING`, including a
# field of an element nested inside the `itineraries` list.
def test_partial(self):
p = structure2structure.Pair(
TripPlan.partial(place='San Francisco'),
TripPlan.partial(itineraries=[Itinerary.partial(day=1)]),
)
self.assertEqual(p.left.itineraries, schema_lib.MISSING)
self.assertEqual(p.right.place, schema_lib.MISSING)
self.assertEqual(p.right.itineraries[0].activities, schema_lib.MISSING)


class CompleteStructureTest(unittest.TestCase):

def test_render_no_examples(self):
l = structure2structure.CompleteStructure()
l = completion.CompleteStructure()
m = lf.UserMessage(
'',
result=TripPlan.partial(
Expand Down Expand Up @@ -115,7 +103,7 @@ class Activity:
)

def test_render_no_class_definitions(self):
l = structure2structure.CompleteStructure()
l = completion.CompleteStructure()
m = lf.UserMessage(
'',
result=TripPlan.partial(
Expand Down Expand Up @@ -182,8 +170,8 @@ def test_render_no_class_definitions(self):
)

def test_render_with_examples(self):
l = structure2structure.CompleteStructure(
examples=structure2structure._default_complete_examples()
l = completion.CompleteStructure(
examples=completion._default_complete_examples()
)
m = lf.UserMessage(
'',
Expand Down Expand Up @@ -323,7 +311,7 @@ def test_transform(self):
),
override_attrs=True,
):
r = structure2structure.complete(
r = completion.complete(
TripPlan.partial(
place='San Francisco',
itineraries=[
Expand Down Expand Up @@ -413,9 +401,7 @@ def test_transform(self):

def test_bad_init(self):
with self.assertRaisesRegex(ValueError, '.*must be.*Pair'):
structure2structure.CompleteStructure(
examples=[mapping.MappingExample(value=1)]
)
completion.CompleteStructure(examples=[mapping.MappingExample(value=1)])

def test_bad_transform(self):
with lf.context(
Expand All @@ -426,14 +412,14 @@ def test_bad_transform(self):
mapping.MappingError,
'Cannot parse message text into structured output',
):
structure2structure.complete(Activity.partial())
completion.complete(Activity.partial())

def test_default(self):
with lf.context(
lm=fake.StaticSequence(['Activity(description=1)']),
override_attrs=True,
):
self.assertIsNone(structure2structure.complete(Activity.partial(), None))
self.assertIsNone(completion.complete(Activity.partial(), None))


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,76 +14,15 @@
"""Structured value to natural language."""

import inspect
from typing import Annotated, Any, Literal
from typing import Any, Literal

import langfun.core as lf
from langfun.core.structured import mapping
from langfun.core.structured import schema as schema_lib
import pyglove as pg


# Base mapping from a structured value to natural-language text.
#
# NOTE(review): the class docstring below contains Jinja syntax and appears to
# double as the prompt template -- treat its text as runtime behavior and do
# not reword it. As written, the template emits: the preamble, then for each
# example an optional context section, the value section and the
# natural-language text section; finally the optional context and the value
# for the current input, ending at the natural-language text title.
class StructureToNaturalLanguage(mapping.Mapping):
"""LangFunc for converting a structured value to natural language.

{{ preamble }}

{% if examples -%}
{% for example in examples -%}
{%- if example.nl_context -%}
{{ nl_context_title}}:
{{ example.nl_context | indent(2, True)}}

{% endif -%}
{{ value_title}}:
{{ value_str(example.value) | indent(2, True) }}

{{ nl_text_title }}:
{{ example.nl_text | indent(2, True) }}

{% endfor %}
{% endif -%}
{% if nl_context -%}
{{ nl_context_title }}:
{{ nl_context | indent(2, True)}}

{% endif -%}
{{ value_title }}:
{{ value_str(value) | indent(2, True) }}

{{ nl_text_title }}:
"""

# Required (no default here); referenced as `{{ preamble }}` in the template.
preamble: Annotated[
lf.LangFunc, 'Preamble used for zeroshot natural language mapping.'
]

# Heading printed before the context section of the template.
nl_context_title: Annotated[str, 'The section title for nl_context.'] = (
'CONTEXT_FOR_DESCRIPTION'
)

# Heading printed before the natural-language text section of the template.
nl_text_title: Annotated[str, 'The section title for nl_text.'] = (
'NATURAL_LANGUAGE_TEXT'
)

# Heading printed before the rendered value in the template.
# NOTE(review): the description string says "schema", but the template uses
# this title for the value section -- likely a copy-paste slip; confirm.
value_title: Annotated[str, 'The section title for schema.'] = 'PYTHON_OBJECT'

# Reads the structured value from `self.message.result`.
@property
def value(self) -> Any:
"""Returns the structured input value."""
return self.message.result

# Reads the context text from `self.message.text`.
@property
def nl_context(self) -> str:
"""Returns the context information for the description."""
return self.message.text

# Renders `value` with the 'python' value representation, with markdown
# disabled and in non-compact form.
def value_str(self, value: Any) -> str:
return schema_lib.value_repr('python').repr(
value, markdown=False, compact=False)


@pg.use_init_args(['examples'])
class DescribeStructure(StructureToNaturalLanguage):
class DescribeStructure(mapping.StructureToNaturalLanguage):
"""Describe a structured value in natural language."""

preamble = """
Expand Down
Loading
Loading