Skip to content

Commit

Permalink
Improve error reporting for LM generated Python/JSON.
Browse files Browse the repository at this point in the history
- Introduce `lf.CodeError`, `lf.JSONError` and `lf.SchemaError`.
- Colorful rendering of generated code and JSON for easy debugging.

PiperOrigin-RevId: 569523585
  • Loading branch information
daiyip authored and langfun authors committed Sep 29, 2023
1 parent ad59934 commit 3f56a1d
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 65 deletions.
5 changes: 5 additions & 0 deletions langfun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@

from langfun.core import memories

# Error types.
SchemaError = structured.SchemaError
JsonError = structured.JsonError
CodeError = coding.CodeError

# Placeholder for Google-internal imports.

# pylint: enable=unused-import
Expand Down
2 changes: 2 additions & 0 deletions langfun/core/coding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# pylint: disable=g-importing-member

from langfun.core.coding.python import CodePermission
from langfun.core.coding.python import CodeError

from langfun.core.coding.python import PythonCode
from langfun.core.coding.python import PythonCodeParser

Expand Down
85 changes: 72 additions & 13 deletions langfun/core/coding/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import enum
import inspect
import io
import textwrap
from typing import Annotated, Any

import langfun.core as lf
Expand Down Expand Up @@ -58,19 +59,64 @@ def ALL(cls) -> 'CodePermission': # pylint: disable=invalid-name
CodePermission.FUNCTION_DEFINITION | CodePermission.IMPORT)


class CodeError(RuntimeError):
"""Python code error."""

def __init__(self, code: str, cause: Exception):
self.code = code
self.cause = cause

def __str__(self):
r = io.StringIO()
r.write(
lf.colored(f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'))

if isinstance(self.cause, SyntaxError):
r.write('\n\n')
r.write(textwrap.indent(
lf.colored(self.cause.text, 'magenta'),
' ' * 2
))
if not self.cause.text.endswith('\n'):
r.write('\n')

r.write('\n')
r.write(lf.colored('Generated Code:', 'red'))
r.write('\n\n')
r.write(lf.colored(' ```python\n', 'magenta'))
r.write(textwrap.indent(
lf.colored(self.code, 'magenta'),
' ' * 2
))
r.write(lf.colored('\n ```\n', 'magenta'))
return r.getvalue()


class PythonCodeParser(lf.Component):
"""Python code parser with permission control."""

class _CodeValidator(ast.NodeVisitor):
"""Python AST node visitor for ensuring code are permitted."""

def __init__(self, perm: CodePermission):
def __init__(self, code: str, perm: CodePermission):
super().__init__()
self.code = code
self.perm = perm

def verify(self, node, flag: CodePermission, node_type, error_message: str):
if isinstance(node, node_type) and not (self.perm & flag):
raise SyntaxError(error_message)
raise SyntaxError(
error_message, (
'<generated-code>',
node.lineno,
node.col_offset,
self._code_line(node.lineno),
node.end_lineno,
node.end_col_offset,
))

def _code_line(self, lineno):
return self.code.split('\n')[lineno - 1]

def generic_visit(self, node):
self.verify(
Expand All @@ -96,7 +142,7 @@ def generic_visit(self, node):
ast.Raise,
ast.Assert
),
'Class definition is not allowed.')
'Exception is not allowed.')

self.verify(
node,
Expand Down Expand Up @@ -127,10 +173,14 @@ def generic_visit(self, node):

super().generic_visit(node)

def parse(self, code_text: str, perm: CodePermission) -> ast.AST:
code_block = ast.parse(self.clean(code_text), mode='exec')
PythonCodeParser._CodeValidator(perm).visit(code_block)
return code_block
def parse(self, code: str, perm: CodePermission) -> tuple[str, ast.AST]:
code = self.clean(code)
try:
parsed_code = ast.parse(code, mode='exec')
PythonCodeParser._CodeValidator(code, perm).visit(parsed_code)
except SyntaxError as e:
raise CodeError(code, e) from e
return code, parsed_code

def clean(self, code_text: str) -> str:
# TODO(daiyip): Deal with markdown in docstrings.
Expand Down Expand Up @@ -273,8 +323,9 @@ def run(
ctx.update(kwargs)

# Parse the code str.
code_block = PythonCodeParser().parse(code, perm)
code, code_block = PythonCodeParser().parse(code, perm)
global_vars, local_vars = ctx, {}

if hasattr(code_block.body[-1], 'value'):
last_expr = code_block.body.pop() # pytype: disable=attribute-error
result_vars = [_FINAL_RESULT_KEY]
Expand All @@ -285,15 +336,23 @@ def run(

last_expr = ast.Expression(last_expr.value) # pytype: disable=attribute-error

# Execute the lines before the last expression.
exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used
try:
# Execute the lines before the last expression.
exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used

# Evaluate the last expression.
result = eval( # pylint: disable=eval-used
compile(last_expr, '', mode='eval'), global_vars, local_vars)
except Exception as e:
raise CodeError(code, e) from e

# Evaluate the last expression.
result = eval(compile(last_expr, '', mode='eval'), global_vars, local_vars) # pylint: disable=eval-used
for result_var in result_vars:
local_vars[result_var] = result
else:
exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used
try:
exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used
except Exception as e:
raise CodeError(code, e) from e
local_vars[_FINAL_RESULT_KEY] = list(local_vars.values())[-1]
return local_vars

Expand Down
16 changes: 14 additions & 2 deletions langfun/core/coding/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,11 @@ class A:
)

def assert_allowed(self, code: str, permission: python.CodePermission):
self.assertIsNotNone(python.PythonCodeParser().parse(code, permission))
_, ast = python.PythonCodeParser().parse(code, permission)
self.assertIsNotNone(ast)

def assert_not_allowed(self, code: str, permission: python.CodePermission):
with self.assertRaisesRegex(SyntaxError, '.* is not allowed'):
with self.assertRaisesRegex(python.CodeError, '.* is not allowed'):
python.PythonCodeParser().parse(code, permission)

def test_parse_with_allowed_code(self):
Expand Down Expand Up @@ -365,6 +366,17 @@ def foo(x, y):
self.assertIsInstance(ret['k'], pg.Object)
self.assertEqual(ret['__result__'], 10)

def test_run_with_error(self):
with self.assertRaisesRegex(
python.CodeError, 'NameError: name .* is not defined'):
python.run(
"""
x = 1
y = x + z
""",
python.CodePermission.ALL
)


class PythonCodeTest(unittest.TestCase):

Expand Down
4 changes: 3 additions & 1 deletion langfun/core/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from langfun.core.structured.schema import SchemaProtocol
from langfun.core.structured.schema import schema_spec

from langfun.core.structured.schema import SchemaError
from langfun.core.structured.schema import JsonError

from langfun.core.structured.schema import class_dependencies
from langfun.core.structured.schema import class_definition
from langfun.core.structured.schema import class_definitions
Expand All @@ -43,7 +46,6 @@

from langfun.core.structured.mapping import Mapping
from langfun.core.structured.mapping import MappingExample
from langfun.core.structured.mapping import MappingError

# Mappings of between different forms of content.
from langfun.core.structured.mapping import NaturalLanguageToStructure
Expand Down
5 changes: 3 additions & 2 deletions langfun/core/structured/completion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

import langfun.core as lf
from langfun.core import coding
from langfun.core.llms import fake
from langfun.core.structured import completion
from langfun.core.structured import mapping
Expand Down Expand Up @@ -409,8 +410,8 @@ def test_bad_transform(self):
override_attrs=True,
):
with self.assertRaisesRegex(
mapping.MappingError,
'Cannot parse message text into structured output',
coding.CodeError,
'Expect .* but encountered .*',
):
completion.complete(Activity.partial())

Expand Down
7 changes: 3 additions & 4 deletions langfun/core/structured/description_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_render(self):
description='Visit Golden Gate Bridge.'
),
Activity(
description='Visit Fisherman's Wharf.'
description="Visit Fisherman's Wharf."
),
Activity(
description='Visit Alcatraz Island.'
Expand Down Expand Up @@ -138,7 +138,6 @@ def test_render_no_examples(self):
hotel=None,
),
)

self.assertEqual(
l.render(message=m).text,
inspect.cleandoc("""
Expand All @@ -160,7 +159,7 @@ def test_render_no_examples(self):
description='Visit Golden Gate Bridge.'
),
Activity(
description='Visit Fisherman's Wharf.'
description="Visit Fisherman's Wharf."
),
Activity(
description='Visit Alcatraz Island.'
Expand Down Expand Up @@ -207,7 +206,7 @@ def test_render_no_context(self):
description='Visit Golden Gate Bridge.'
),
Activity(
description='Visit Fisherman's Wharf.'
description="Visit Fisherman's Wharf."
),
Activity(
description='Visit Alcatraz Island.'
Expand Down
20 changes: 2 additions & 18 deletions langfun/core/structured/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@
import pyglove as pg


class MappingError(ValueError): # pylint: disable=g-bad-exception-name
"""Mappingg error."""

def __eq__(self, other):
return isinstance(other, MappingError) and self.args == other.args

def __ne__(self, other):
return not self.__eq__(other)


# NOTE(daiyip): We put `schema` at last as it could inherit from the parent
# objects.
@pg.use_init_args(['nl_context', 'nl_text', 'value', 'schema'])
Expand Down Expand Up @@ -245,10 +235,7 @@ def transform_output(self, lm_output: lf.Message) -> lf.Message:
)
except Exception as e: # pylint: disable=broad-exception-caught
if self.default == lf.message_transform.RAISE_IF_HAS_ERROR:
raise MappingError(
'Cannot parse message text into structured output. '
f'Error={e}. Text={lm_output.text!r}.'
) from e
raise e
lm_output.result = self.default
return lm_output

Expand Down Expand Up @@ -416,10 +403,7 @@ def transform_output(self, lm_output: lf.Message) -> lf.Message:
)
except Exception as e: # pylint: disable=broad-exception-caught
if self.default == lf.message_transform.RAISE_IF_HAS_ERROR:
raise MappingError(
'Cannot parse message text into structured output. '
f'Error={e}. Text={lm_output.text!r}.'
) from e
raise e
result = self.default
lm_output.result = result
return lm_output
9 changes: 0 additions & 9 deletions langfun/core/structured/mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,6 @@
import pyglove as pg


class MappingErrorTest(unittest.TestCase):

def test_eq(self):
self.assertEqual(
mapping.MappingError('Parse failed.'),
mapping.MappingError('Parse failed.')
)


class MappingExampleTest(unittest.TestCase):

def test_basics(self):
Expand Down
10 changes: 6 additions & 4 deletions langfun/core/structured/parsing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import unittest

import langfun.core as lf
from langfun.core import coding
from langfun.core.llms import fake

from langfun.core.structured import mapping
from langfun.core.structured import parsing
import pyglove as pg
Expand Down Expand Up @@ -241,8 +243,8 @@ def test_bad_transform(self):
override_attrs=True,
):
with self.assertRaisesRegex(
mapping.MappingError,
'Cannot parse message text into structured output',
coding.CodeError,
'name .* is not defined',
):
lf.LangFunc('Compute 1 + 2').as_structured(int)()

Expand Down Expand Up @@ -506,8 +508,8 @@ def test_bad_transform(self):
override_attrs=True,
):
with self.assertRaisesRegex(
mapping.MappingError,
'Cannot parse message text into structured output',
coding.CodeError,
'invalid syntax',
):
lf.LangFunc('Compute 1 + 2').as_structured(int)()

Expand Down
10 changes: 6 additions & 4 deletions langfun/core/structured/prompting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
import unittest

import langfun.core as lf
from langfun.core import coding
from langfun.core.llms import fake
from langfun.core.structured import mapping
from langfun.core.structured import prompting
from langfun.core.structured import schema as schema_lib
import pyglove as pg


Expand Down Expand Up @@ -178,8 +180,8 @@ def test_bad_transform(self):
override_attrs=True,
):
with self.assertRaisesRegex(
mapping.MappingError,
'Cannot parse message text into structured output',
coding.CodeError,
'name .* is not defined',
):
prompting.query('Compute 1 + 2', int)

Expand Down Expand Up @@ -374,8 +376,8 @@ def test_bad_transform(self):
override_attrs=True,
):
with self.assertRaisesRegex(
mapping.MappingError,
'Cannot parse message text into structured output',
schema_lib.JsonError,
'No JSON dict in the output',
):
prompting.query('Compute 1 + 2', int, protocol='json')

Expand Down
Loading

0 comments on commit 3f56a1d

Please sign in to comment.