Skip to content

Commit

Permalink
Clean up lf.structured.
Browse files Browse the repository at this point in the history
1) Rename prompting.py to query.py as it's easier for users to locate the source code of `lf.query`.
2) Make helper classes without backward compatibility guarantee private.

PiperOrigin-RevId: 705976528
  • Loading branch information
daiyip authored and langfun authors committed Dec 13, 2024
1 parent 8585231 commit be31028
Show file tree
Hide file tree
Showing 16 changed files with 191 additions and 204 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/langfun101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/langfun/blob/main/docs/notebooks/langfun101.ipynb)\n",
"\n",
"Effective programming of Large Language Models (LLMs) demands a seamless integration of natural language text with structured data. Langfun, leveraging [PyGlove](https://github.com/google/pyglove)'s symbolic objects, provides a simple yet powerful interface for mapping between Python objects with the assistance of LLMs. The input/output objects may include natural language texts (in string form), structured data (objects of a specific class), modalities (such as images), and more. The unified API for accomplishing all conceivable mappings is [`lf.query`](https://github.com/google/langfun/blob/145f4a48603d2e6e19f9025e032ac3c86dfd3e35/langfun/core/structured/prompting.py#L103).\n"
"Effective programming of Large Language Models (LLMs) demands a seamless integration of natural language text with structured data. Langfun, leveraging [PyGlove](https://github.com/google/pyglove)'s symbolic objects, provides a simple yet powerful interface for mapping between Python objects with the assistance of LLMs. The input/output objects may include natural language texts (in string form), structured data (objects of a specific class), modalities (such as images), and more. The unified API for accomplishing all conceivable mappings is [`lf.query`](https://github.com/google/langfun/blob/145f4a48603d2e6e19f9025e032ac3c86dfd3e35/langfun/core/structured/query.py#L103).\n"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions langfun/core/coding/python/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def run_with_correction(
# Delay import at runtime to avoid circular dependency.
# pylint: disable=g-import-not-at-top
# pytype: disable=import-error
from langfun.core.structured import prompting
from langfun.core.structured import query as query_lib
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top

Expand Down Expand Up @@ -119,7 +119,7 @@ def result_and_error(code: str) -> tuple[Any, str | None]:
# structure.
try:
# Disable autofix for code correction to avoid recursion.
correction = prompting.query(
correction = query_lib.query(
CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0
)
except errors.CodeError:
Expand Down
30 changes: 8 additions & 22 deletions langfun/core/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@
from langfun.core.structured.schema import annotation
from langfun.core.structured.schema import structure_from_python

from langfun.core.structured.schema import SchemaRepr
from langfun.core.structured.schema import SchemaJsonRepr
from langfun.core.structured.schema import SchemaPythonRepr
from langfun.core.structured.schema import ValueRepr
from langfun.core.structured.schema import ValueJsonRepr
from langfun.core.structured.schema import ValuePythonRepr
from langfun.core.structured.schema import schema_repr
from langfun.core.structured.schema import source_form
from langfun.core.structured.schema import value_repr
Expand All @@ -56,26 +50,18 @@
from langfun.core.structured.mapping import MappingError
from langfun.core.structured.mapping import MappingExample

from langfun.core.structured.parsing import ParseStructure
from langfun.core.structured.parsing import ParseStructureJson
from langfun.core.structured.parsing import ParseStructurePython
from langfun.core.structured.parsing import parse
from langfun.core.structured.parsing import call

from langfun.core.structured.prompting import QueryStructure
from langfun.core.structured.prompting import QueryStructureJson
from langfun.core.structured.prompting import QueryStructurePython
from langfun.core.structured.prompting import query
from langfun.core.structured.prompting import query_prompt
from langfun.core.structured.prompting import query_output
from langfun.core.structured.prompting import query_reward
from langfun.core.structured.prompting import QueryInvocation
from langfun.core.structured.prompting import track_queries

from langfun.core.structured.description import DescribeStructure
from langfun.core.structured.description import describe
from langfun.core.structured.query import query
from langfun.core.structured.query import query_prompt
from langfun.core.structured.query import query_output
from langfun.core.structured.query import query_reward
from langfun.core.structured.query import QueryInvocation
from langfun.core.structured.query import track_queries
from langfun.core.structured import query as query_lib

from langfun.core.structured.completion import CompleteStructure
from langfun.core.structured.description import describe
from langfun.core.structured.completion import complete

from langfun.core.structured.scoring import score
Expand Down
4 changes: 2 additions & 2 deletions langfun/core/structured/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pyglove as pg


class CompleteStructure(mapping.Mapping):
class _CompleteStructure(mapping.Mapping):
"""Complete structure by filling the missing fields."""

input: Annotated[
Expand Down Expand Up @@ -241,7 +241,7 @@ class Flight(pg.Object):
Returns:
The result based on the schema.
"""
t = CompleteStructure(
t = _CompleteStructure(
input=schema_lib.mark_missing(input_value),
default=default,
examples=examples,
Expand Down
8 changes: 4 additions & 4 deletions langfun/core/structured/completion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TripPlan(pg.Object):
class CompleteStructureTest(unittest.TestCase):

def test_render_no_examples(self):
l = completion.CompleteStructure()
l = completion._CompleteStructure()
input_value = schema_lib.mark_missing(
TripPlan.partial(
place='San Francisco',
Expand Down Expand Up @@ -120,7 +120,7 @@ class Activity:
)

def test_render_no_class_definitions(self):
l = completion.CompleteStructure()
l = completion._CompleteStructure()
input_value = schema_lib.mark_missing(
TripPlan.partial(
place='San Francisco',
Expand Down Expand Up @@ -200,7 +200,7 @@ def test_render_no_class_definitions(self):
)

def test_render_with_examples(self):
l = completion.CompleteStructure()
l = completion._CompleteStructure()
input_value = schema_lib.mark_missing(
TripPlan.partial(
place='San Francisco',
Expand Down Expand Up @@ -411,7 +411,7 @@ class Animal(pg.Object):
modalities.Image.from_bytes(b'image_of_elephant'),
)
)
l = completion.CompleteStructure(
l = completion._CompleteStructure(
input=input_value,
examples=[
mapping.MappingExample(
Expand Down
4 changes: 2 additions & 2 deletions langfun/core/structured/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


@pg.use_init_args(['examples'])
class DescribeStructure(mapping.Mapping):
class _DescribeStructure(mapping.Mapping):
"""Describe a structured value in natural language."""

input_title = 'PYTHON_OBJECT'
Expand Down Expand Up @@ -106,7 +106,7 @@ class Flight(pg.Object):
Returns:
The parsed result based on the schema.
"""
return DescribeStructure(
return _DescribeStructure(
input=value,
context=context,
examples=examples or default_describe_examples(),
Expand Down
6 changes: 3 additions & 3 deletions langfun/core/structured/description_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Itinerary(pg.Object):
class DescribeStructureTest(unittest.TestCase):

def test_render(self):
l = description_lib.DescribeStructure(
l = description_lib._DescribeStructure(
input=Itinerary(
day=1,
type='daytime',
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_render_no_examples(self):
],
hotel=None,
)
l = description_lib.DescribeStructure(
l = description_lib._DescribeStructure(
input=value, context='1 day itinerary to SF'
)
self.assertEqual(
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_render_no_context(self):
],
hotel=None,
)
l = description_lib.DescribeStructure(input=value)
l = description_lib._DescribeStructure(input=value)
self.assertEqual(
l.render().text,
inspect.cleandoc("""
Expand Down
6 changes: 3 additions & 3 deletions langfun/core/structured/function_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from langfun.core import language_model
from langfun.core import template
from langfun.core.coding import python
from langfun.core.structured import prompting
from langfun.core.structured import query as query_lib
import pyglove as pg


Expand All @@ -39,7 +39,7 @@ class PythonFunctionSignature(pg.Object):

unittest_examples = None
for _ in range(num_retries):
r = prompting.query(
r = query_lib.query(
PythonFunctionSignature(signature=signature),
list[UnitTest],
lm=lm,
Expand Down Expand Up @@ -145,7 +145,7 @@ def calculate_area_circle(radius: float) -> float:
last_error = None
for _ in range(num_retries):
try:
source_code = prompting.query(
source_code = query_lib.query(
PythonFunctionPrompt(signature=signature), lm=lm
)
f = python.evaluate(source_code, global_vars=context)
Expand Down
16 changes: 8 additions & 8 deletions langfun/core/structured/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

import langfun.core as lf
from langfun.core.structured import mapping
from langfun.core.structured import prompting
from langfun.core.structured import query as query_lib
from langfun.core.structured import schema as schema_lib
import pyglove as pg


@lf.use_init_args(['schema', 'default', 'examples'])
class ParseStructure(mapping.Mapping):
class _ParseStructure(mapping.Mapping):
"""Parse an object out from a natural language text."""

context_title = 'USER_REQUEST'
Expand All @@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping):
]


class ParseStructureJson(ParseStructure):
class _ParseStructureJson(_ParseStructure):
"""Parse an object out from a NL text using JSON as the protocol."""

preamble = """
Expand All @@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure):
output_title = 'JSON'


class ParseStructurePython(ParseStructure):
class _ParseStructurePython(_ParseStructure):
"""Parse an object out from a NL text using Python as the protocol."""

preamble = """
Expand Down Expand Up @@ -271,7 +271,7 @@ def call(
return lm_output if returns_message else lm_output.text

# Call `parsing_lm` for structured parsing.
parsing_message = prompting.query(
parsing_message = query_lib.query(
lm_output.text,
schema,
examples=parsing_examples,
Expand All @@ -293,11 +293,11 @@ def call(

def _parse_structure_cls(
protocol: schema_lib.SchemaProtocol,
) -> Type[ParseStructure]:
) -> Type[_ParseStructure]:
if protocol == 'json':
return ParseStructureJson
return _ParseStructureJson
elif protocol == 'python':
return ParseStructurePython
return _ParseStructurePython
else:
raise ValueError(f'Unknown protocol: {protocol!r}.')

Expand Down
16 changes: 8 additions & 8 deletions langfun/core/structured/parsing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Itinerary(pg.Object):
class ParseStructurePythonTest(unittest.TestCase):

def test_render_no_examples(self):
l = parsing.ParseStructurePython(int)
l = parsing._ParseStructurePython(int)
m = lf.AIMessage('12 / 6 + 2 = 4')
self.assertEqual(
l.render(input=m, context='Compute 12 / 6 + 2.').text,
Expand All @@ -62,7 +62,7 @@ def test_render_no_examples(self):
)

def test_render_no_context(self):
l = parsing.ParseStructurePython(int)
l = parsing._ParseStructurePython(int)
m = lf.AIMessage('12 / 6 + 2 = 4')

self.assertEqual(
Expand All @@ -85,7 +85,7 @@ def test_render_no_context(self):
)

def test_render(self):
l = parsing.ParseStructurePython(
l = parsing._ParseStructurePython(
int,
examples=[
mapping.MappingExample(
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_invocation(self):
),
override_attrs=True,
):
l = parsing.ParseStructurePython(
l = parsing._ParseStructurePython(
[Itinerary],
examples=[
mapping.MappingExample(
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_parse(self):
class ParseStructureJsonTest(unittest.TestCase):

def test_render_no_examples(self):
l = parsing.ParseStructureJson(int)
l = parsing._ParseStructureJson(int)
m = lf.AIMessage('12 / 6 + 2 = 4')
self.assertEqual(
l.render(input=m, context='Compute 12 / 6 + 2.').text,
Expand All @@ -320,7 +320,7 @@ def test_render_no_examples(self):
)

def test_render_no_context(self):
l = parsing.ParseStructureJson(int)
l = parsing._ParseStructureJson(int)
m = lf.AIMessage('12 / 6 + 2 = 4')

self.assertEqual(
Expand All @@ -343,7 +343,7 @@ def test_render_no_context(self):
)

def test_render(self):
l = parsing.ParseStructureJson(
l = parsing._ParseStructureJson(
int,
examples=[
mapping.MappingExample(
Expand Down Expand Up @@ -504,7 +504,7 @@ def test_invocation(self):
override_attrs=True,
):
message = lf.LangFunc(lm_input)()
l = parsing.ParseStructureJson(
l = parsing._ParseStructureJson(
[Itinerary],
examples=[
mapping.MappingExample(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@lf.use_init_args(['schema', 'default', 'examples'])
class QueryStructure(mapping.Mapping):
class _QueryStructure(mapping.Mapping):
"""Query an object out from a natural language text."""

context_title = 'CONTEXT'
Expand All @@ -38,7 +38,7 @@ class QueryStructure(mapping.Mapping):
]


class QueryStructureJson(QueryStructure):
class _QueryStructureJson(_QueryStructure):
"""Query a structured value using JSON as the protocol."""

preamble = """
Expand All @@ -52,18 +52,18 @@ class QueryStructureJson(QueryStructure):
1 + 1 =
{{ schema_title }}:
{"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": int}}
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
{{ output_title}}:
{"result": {"_type": "langfun.core.structured.prompting.Answer", "final_answer": 2}}
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
"""

protocol = 'json'
schema_title = 'SCHEMA'
output_title = 'JSON'


class QueryStructurePython(QueryStructure):
class _QueryStructurePython(_QueryStructure):
"""Query a structured value using Python as the protocol."""

preamble = """
Expand Down Expand Up @@ -94,11 +94,11 @@ class Answer:

def _query_structure_cls(
protocol: schema_lib.SchemaProtocol,
) -> Type[QueryStructure]:
) -> Type[_QueryStructure]:
if protocol == 'json':
return QueryStructureJson
return _QueryStructureJson
elif protocol == 'python':
return QueryStructurePython
return _QueryStructurePython
else:
raise ValueError(f'Unknown protocol: {protocol!r}.')

Expand Down
Loading

0 comments on commit be31028

Please sign in to comment.