Skip to content

Commit

Permalink
Introduce __validate__: User feedback for lf.query on autofix.
Browse files Browse the repository at this point in the history
Example:

```
 class Foo(pg.Object):
      x: int

      def __validate__(self):
        # This triggers autofix.
        if self.x > 1:
          raise ValueError('value should be less or equal than 1.')

        # This triggers rule-based fix (user modification)
        if self.x < 0:
          self.rebind(x=0, skip_notification=True)

   foo = lf.query('Generate a Foo object', Foo, lm=lm, autofix=2)
   assert foo.x in (0, 1)
```

PiperOrigin-RevId: 672124993
  • Loading branch information
daiyip authored and langfun authors committed Sep 7, 2024
1 parent 77a893e commit 5cb7c03
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 12 deletions.
40 changes: 28 additions & 12 deletions langfun/core/coding/python/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,27 @@ def run_with_correction(
# pylint: enable=g-import-not-at-top

if max_attempts == 0:
result = execution.run(
code,
global_vars=global_vars,
sandbox=sandbox,
timeout=timeout,
outputs_intermediate=outputs_intermediate,
result = _maybe_custom_validate(
execution.run(
code,
global_vars=global_vars,
sandbox=sandbox,
timeout=timeout,
outputs_intermediate=outputs_intermediate,
)
)
return (result, code) if returns_code else result

def result_and_error(code: str) -> tuple[Any, str | None]:
try:
result = execution.run(
code,
global_vars=global_vars,
sandbox=sandbox,
timeout=timeout,
outputs_intermediate=outputs_intermediate,
result = _maybe_custom_validate(
execution.run(
code,
global_vars=global_vars,
sandbox=sandbox,
timeout=timeout,
outputs_intermediate=outputs_intermediate,
)
)
return (result, None)
except Exception as e: # pylint: disable=broad-exception-caught
Expand Down Expand Up @@ -190,3 +194,15 @@ def _error_feedback_str(error: Exception) -> str:
)
else:
return f"Encountered {error.__class__.__name__}: {error}"


def _maybe_custom_validate(result: Any) -> Any:
"""Apply custom validation through __validate_generation__ method."""
if isinstance(result, dict) and "__result__" in result:
r = result["__result__"]
else:
r = result

if hasattr(r, "__validate__"):
r.__validate__()
return result
27 changes: 27 additions & 0 deletions langfun/core/coding/python/correction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from langfun.core.coding.python import correction
from langfun.core.coding.python import errors
from langfun.core.llms import fake
import pyglove as pg


class RunWithCorrectionTest(unittest.TestCase):
Expand All @@ -45,6 +46,32 @@ def test_run_with_correction(self):
)
self.assertEqual(result, 4)

def test_run_with_correction_upon_custom_validation(self):

class Foo(pg.Object):
x: int

def __validate__(self):
if self.x > 1:
raise ValueError('value should be less or equal than 1.')
if self.x < 0:
self.rebind(x=0, skip_notification=True)

result = correction.run_with_correction(
inspect.cleandoc("""
Foo(x=2)
"""),
global_vars=dict(Foo=Foo),
lm=fake.StaticSequence([
inspect.cleandoc("""
CorrectedCode(
corrected_code='Foo(x=-1)',
)
"""),
]),
)
self.assertEqual(result, Foo(0))

def test_run_without_correction(self):
result = correction.run_with_correction(
inspect.cleandoc("""
Expand Down

0 comments on commit 5cb7c03

Please sign in to comment.