diff --git a/langfun/__init__.py b/langfun/__init__.py index fe86aff..1ade81b 100644 --- a/langfun/__init__.py +++ b/langfun/__init__.py @@ -39,6 +39,11 @@ from langfun.core import memories +# Error types. +SchemaError = structured.SchemaError +JsonError = structured.JsonError +CodeError = coding.CodeError + # Placeholder for Google-internal imports. # pylint: enable=unused-import diff --git a/langfun/core/coding/__init__.py b/langfun/core/coding/__init__.py index 829d9f6..fa4338a 100644 --- a/langfun/core/coding/__init__.py +++ b/langfun/core/coding/__init__.py @@ -17,6 +17,8 @@ # pylint: disable=g-importing-member from langfun.core.coding.python import CodePermission +from langfun.core.coding.python import CodeError + from langfun.core.coding.python import PythonCode from langfun.core.coding.python import PythonCodeParser diff --git a/langfun/core/coding/python.py b/langfun/core/coding/python.py index e8c5e92..2db1474 100644 --- a/langfun/core/coding/python.py +++ b/langfun/core/coding/python.py @@ -18,6 +18,7 @@ import enum import inspect import io +import textwrap from typing import Annotated, Any import langfun.core as lf @@ -58,19 +59,64 @@ def ALL(cls) -> 'CodePermission': # pylint: disable=invalid-name CodePermission.FUNCTION_DEFINITION | CodePermission.IMPORT) +class CodeError(RuntimeError): + """Python code error.""" + + def __init__(self, code: str, cause: Exception): + self.code = code + self.cause = cause + + def __str__(self): + r = io.StringIO() + r.write( + lf.colored(f'{self.cause.__class__.__name__}: {self.cause}', 'magenta')) + + if isinstance(self.cause, SyntaxError): + r.write('\n\n') + r.write(textwrap.indent( + lf.colored(self.cause.text, 'magenta'), + ' ' * 2 + )) + if not self.cause.text.endswith('\n'): + r.write('\n') + + r.write('\n') + r.write(lf.colored('Generated Code:', 'red')) + r.write('\n\n') + r.write(lf.colored(' ```python\n', 'magenta')) + r.write(textwrap.indent( + lf.colored(self.code, 'magenta'), + ' ' * 2 + )) + r.write(lf.colored('\n ```\n', 'magenta')) + return r.getvalue() + + class PythonCodeParser(lf.Component): """Python code parser with permission control.""" class _CodeValidator(ast.NodeVisitor): """Python AST node visitor for ensuring code are permitted.""" - def __init__(self, perm: CodePermission): + def __init__(self, code: str, perm: CodePermission): super().__init__() + self.code = code self.perm = perm def verify(self, node, flag: CodePermission, node_type, error_message: str): if isinstance(node, node_type) and not (self.perm & flag): - raise SyntaxError(error_message) + raise SyntaxError( + error_message, ( + '', + node.lineno, + node.col_offset, + self._code_line(node.lineno), + node.end_lineno, + node.end_col_offset, + )) + + def _code_line(self, lineno): + return self.code.split('\n')[lineno - 1] def generic_visit(self, node): self.verify( @@ -96,7 +142,7 @@ def generic_visit(self, node): ast.Raise, ast.Assert ), - 'Class definition is not allowed.') + 'Exception is not allowed.') self.verify( node, @@ -127,10 +173,14 @@ def generic_visit(self, node): super().generic_visit(node) - def parse(self, code_text: str, perm: CodePermission) -> ast.AST: - code_block = ast.parse(self.clean(code_text), mode='exec') - PythonCodeParser._CodeValidator(perm).visit(code_block) - return code_block + def parse(self, code: str, perm: CodePermission) -> tuple[str, ast.AST]: + code = self.clean(code) + try: + parsed_code = ast.parse(code, mode='exec') + PythonCodeParser._CodeValidator(code, perm).visit(parsed_code) + except SyntaxError as e: + raise CodeError(code, e) from e + return code, parsed_code def clean(self, code_text: str) -> str: # TODO(daiyip): Deal with markdown in docstrings. @@ -273,8 +323,9 @@ def run( ctx.update(kwargs) # Parse the code str. - code_block = PythonCodeParser().parse(code, perm) + code, code_block = PythonCodeParser().parse(code, perm) global_vars, local_vars = ctx, {} + if hasattr(code_block.body[-1], 'value'): last_expr = code_block.body.pop() # pytype: disable=attribute-error result_vars = [_FINAL_RESULT_KEY] @@ -285,15 +336,23 @@ def run( last_expr = ast.Expression(last_expr.value) # pytype: disable=attribute-error - # Execute the lines before the last expression. - exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used + try: + # Execute the lines before the last expression. + exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used + + # Evaluate the last expression. + result = eval( # pylint: disable=eval-used + compile(last_expr, '', mode='eval'), global_vars, local_vars) + except Exception as e: + raise CodeError(code, e) from e - # Evaluate the last expression. - result = eval(compile(last_expr, '', mode='eval'), global_vars, local_vars) # pylint: disable=eval-used for result_var in result_vars: local_vars[result_var] = result else: - exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used + try: + exec(compile(code_block, '', mode='exec'), global_vars, local_vars) # pylint: disable=exec-used + except Exception as e: + raise CodeError(code, e) from e local_vars[_FINAL_RESULT_KEY] = list(local_vars.values())[-1] return local_vars diff --git a/langfun/core/coding/python_test.py b/langfun/core/coding/python_test.py index ca5867d..5d1c0f1 100644 --- a/langfun/core/coding/python_test.py +++ b/langfun/core/coding/python_test.py @@ -186,10 +186,11 @@ class A: ) def assert_allowed(self, code: str, permission: python.CodePermission): - self.assertIsNotNone(python.PythonCodeParser().parse(code, permission)) + _, ast = python.PythonCodeParser().parse(code, permission) + self.assertIsNotNone(ast) def assert_not_allowed(self, code: str, permission: python.CodePermission): - with self.assertRaisesRegex(SyntaxError, '.* is not allowed'): + with self.assertRaisesRegex(python.CodeError, '.* is not allowed'): python.PythonCodeParser().parse(code, permission) def test_parse_with_allowed_code(self): @@ -365,6 +366,17 @@ def foo(x, y): self.assertIsInstance(ret['k'], pg.Object) self.assertEqual(ret['__result__'], 10) + def test_run_with_error(self): + with self.assertRaisesRegex( + python.CodeError, 'NameError: name .* is not defined'): + python.run( + """ + x = 1 + y = x + z + """, + python.CodePermission.ALL + ) + class PythonCodeTest(unittest.TestCase): diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index c6d66e5..8f03f60 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -25,6 +25,9 @@ from langfun.core.structured.schema import SchemaProtocol from langfun.core.structured.schema import schema_spec +from langfun.core.structured.schema import SchemaError +from langfun.core.structured.schema import JsonError + from langfun.core.structured.schema import class_dependencies from langfun.core.structured.schema import class_definition from langfun.core.structured.schema import class_definitions @@ -43,7 +46,6 @@ from langfun.core.structured.mapping import Mapping from langfun.core.structured.mapping import MappingExample -from langfun.core.structured.mapping import MappingError # Mappings of between different forms of content. from langfun.core.structured.mapping import NaturalLanguageToStructure diff --git a/langfun/core/structured/completion_test.py b/langfun/core/structured/completion_test.py index dfd8f8b..345fb2a 100644 --- a/langfun/core/structured/completion_test.py +++ b/langfun/core/structured/completion_test.py @@ -17,6 +17,7 @@ import unittest import langfun.core as lf +from langfun.core import coding from langfun.core.llms import fake from langfun.core.structured import completion from langfun.core.structured import mapping @@ -409,8 +410,8 @@ def test_bad_transform(self): override_attrs=True, ): with self.assertRaisesRegex( - mapping.MappingError, - 'Cannot parse message text into structured output', + coding.CodeError, + 'Expect .* but encountered .*', ): completion.complete(Activity.partial()) diff --git a/langfun/core/structured/description_test.py b/langfun/core/structured/description_test.py index 5715962..8f4ed1b 100644 --- a/langfun/core/structured/description_test.py +++ b/langfun/core/structured/description_test.py @@ -110,7 +110,7 @@ def test_render(self): description='Visit Golden Gate Bridge.' ), Activity( - description='Visit Fisherman's Wharf.' + description="Visit Fisherman's Wharf." ), Activity( description='Visit Alcatraz Island.' @@ -138,7 +138,6 @@ def test_render_no_examples(self): hotel=None, ), ) - self.assertEqual( l.render(message=m).text, inspect.cleandoc(""" @@ -160,7 +159,7 @@ def test_render_no_examples(self): description='Visit Golden Gate Bridge.' ), Activity( - description='Visit Fisherman's Wharf.' + description="Visit Fisherman's Wharf." ), Activity( description='Visit Alcatraz Island.' @@ -207,7 +206,7 @@ def test_render_no_context(self): description='Visit Golden Gate Bridge.' ), Activity( - description='Visit Fisherman's Wharf.' + description="Visit Fisherman's Wharf." ), Activity( description='Visit Alcatraz Island.' diff --git a/langfun/core/structured/mapping.py b/langfun/core/structured/mapping.py index 809ffba..206c0aa 100644 --- a/langfun/core/structured/mapping.py +++ b/langfun/core/structured/mapping.py @@ -21,16 +21,6 @@ import pyglove as pg -class MappingError(ValueError): # pylint: disable=g-bad-exception-name - """Mappingg error.""" - - def __eq__(self, other): - return isinstance(other, MappingError) and self.args == other.args - - def __ne__(self, other): - return not self.__eq__(other) - - # NOTE(daiyip): We put `schema` at last as it could inherit from the parent # objects. @pg.use_init_args(['nl_context', 'nl_text', 'value', 'schema']) @@ -245,10 +235,7 @@ def transform_output(self, lm_output: lf.Message) -> lf.Message: ) except Exception as e: # pylint: disable=broad-exception-caught if self.default == lf.message_transform.RAISE_IF_HAS_ERROR: - raise MappingError( - 'Cannot parse message text into structured output. ' - f'Error={e}. Text={lm_output.text!r}.' - ) from e + raise e lm_output.result = self.default return lm_output @@ -416,10 +403,7 @@ def transform_output(self, lm_output: lf.Message) -> lf.Message: ) except Exception as e: # pylint: disable=broad-exception-caught if self.default == lf.message_transform.RAISE_IF_HAS_ERROR: - raise MappingError( - 'Cannot parse message text into structured output. ' - f'Error={e}. Text={lm_output.text!r}.' - ) from e + raise e result = self.default lm_output.result = result return lm_output diff --git a/langfun/core/structured/mapping_test.py b/langfun/core/structured/mapping_test.py index 8fd8229..40b0ae1 100644 --- a/langfun/core/structured/mapping_test.py +++ b/langfun/core/structured/mapping_test.py @@ -21,15 +21,6 @@ import pyglove as pg -class MappingErrorTest(unittest.TestCase): - - def test_eq(self): - self.assertEqual( - mapping.MappingError('Parse failed.'), - mapping.MappingError('Parse failed.') - ) - - class MappingExampleTest(unittest.TestCase): def test_basics(self): diff --git a/langfun/core/structured/parsing_test.py b/langfun/core/structured/parsing_test.py index 1ea440c..8f8bd27 100644 --- a/langfun/core/structured/parsing_test.py +++ b/langfun/core/structured/parsing_test.py @@ -17,7 +17,9 @@ import unittest import langfun.core as lf +from langfun.core import coding from langfun.core.llms import fake + from langfun.core.structured import mapping from langfun.core.structured import parsing import pyglove as pg @@ -241,8 +243,8 @@ def test_bad_transform(self): override_attrs=True, ): with self.assertRaisesRegex( - mapping.MappingError, - 'Cannot parse message text into structured output', + coding.CodeError, + 'name .* is not defined', ): lf.LangFunc('Compute 1 + 2').as_structured(int)() @@ -506,8 +508,8 @@ def test_bad_transform(self): override_attrs=True, ): with self.assertRaisesRegex( - mapping.MappingError, - 'Cannot parse message text into structured output', + coding.CodeError, + 'invalid syntax', ): lf.LangFunc('Compute 1 + 2').as_structured(int)() diff --git a/langfun/core/structured/prompting_test.py b/langfun/core/structured/prompting_test.py index 4058df7..384384a 100644 --- a/langfun/core/structured/prompting_test.py +++ b/langfun/core/structured/prompting_test.py @@ -17,9 +17,11 @@ import unittest import langfun.core as lf +from langfun.core import coding from langfun.core.llms import fake from langfun.core.structured import mapping from langfun.core.structured import prompting +from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -178,8 +180,8 @@ def test_bad_transform(self): override_attrs=True, ): with self.assertRaisesRegex( - mapping.MappingError, - 'Cannot parse message text into structured output', + coding.CodeError, + 'name .* is not defined', ): prompting.query('Compute 1 + 2', int) @@ -374,8 +376,8 @@ def test_bad_transform(self): override_attrs=True, ): with self.assertRaisesRegex( - mapping.MappingError, - 'Cannot parse message text into structured output', + schema_lib.JsonError, + 'No JSON dict in the output', ): prompting.query('Compute 1 + 2', int, protocol='json') diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index c1c70ee..a66546d 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -16,6 +16,7 @@ import abc import inspect import io +import textwrap import typing from typing import Any, Literal, Sequence, Type, Union import langfun.core as lf @@ -66,6 +67,41 @@ def _parse_node(v) -> pg.typing.ValueSpec: SchemaProtocol = Literal['json', 'python'] +class SchemaError(Exception): # pylint: disable=g-bad-exception-name + """Schema error.""" + + def __init__(self, + schema: 'Schema', + value: Any, + protocol: SchemaProtocol, + cause: Exception): + self.schema = schema + self.value = value + self.protocol = protocol + self.cause = cause + + def __str__(self): + r = io.StringIO() + r.write( + lf.colored(f'{self.cause.__class__.__name__}: {self.cause}', 'magenta')) + + r.write('\n') + r.write(lf.colored('Schema:', 'red')) + r.write('\n\n') + r.write(textwrap.indent( + lf.colored(schema_repr(self.protocol).repr(self.schema), 'magenta'), + ' ' * 2 + )) + r.write('\n\n') + r.write(lf.colored('Generated value:', 'red')) + r.write('\n\n') + r.write(textwrap.indent( + lf.colored(value_repr(self.protocol).repr(self.value), 'magenta'), + ' ' * 2 + )) + return r.getvalue() + + class Schema(lf.NaturalLanguageFormattable, pg.Object): """Base class for structured data schema.""" @@ -91,7 +127,12 @@ def parse( self, text: str, protocol: SchemaProtocol = 'json', **kwargs ) -> Any: """Parse a LM generated text into a structured value.""" - return self.spec.apply(value_repr(protocol).parse(text, self, **kwargs)) + value = value_repr(protocol).parse(text, self, **kwargs) + + try: + return self.spec.apply(value) + except Exception as e: + raise SchemaError(self, value, protocol, e) # pylint: disable=raise-missing-from def natural_language_format(self) -> str: return self.schema_str() @@ -529,6 +570,25 @@ def structure_from_python(code: str, **symbols) -> Any: return lf_coding.run(code)['__result__'] +class JsonError(Exception): + """Json parsing error.""" + + def __init__(self, json: str, cause: Exception): + self.json = json + self.cause = cause + + def __str__(self) -> str: + r = io.StringIO() + r.write( + lf.colored(f'{self.cause.__class__.__name__}: {self.cause}', 'magenta')) + + r.write('\n\n') + r.write(lf.colored('JSON text:', 'red')) + r.write('\n\n') + r.write(textwrap.indent(lf.colored(self.json, 'magenta'), ' ' * 2)) + return r.getvalue() + + class ValueJsonRepr(ValueRepr): """JSON-representation for value.""" @@ -539,12 +599,17 @@ def repr(self, value: Any, schema: Schema | None = None, **kwargs) -> str: def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any: """Parse a JSON string into a structured object.""" del schema - v = pg.from_json_str(self._cleanup_json(text)) + try: + text = self._cleanup_json(text) + v = pg.from_json_str(text) + except Exception as e: + raise JsonError(text, e) # pylint: disable=raise-missing-from + if not isinstance(v, dict) or 'result' not in v: - raise ValueError( + raise JsonError(text, ValueError( 'The root node of the JSON must be a dict with key `result`. ' f'Encountered: {v}' - ) + )) return v['result'] def _cleanup_json(self, json_str: str) -> str: diff --git a/langfun/core/structured/schema_test.py b/langfun/core/structured/schema_test.py index 98a40e7..e9ac8e7 100644 --- a/langfun/core/structured/schema_test.py +++ b/langfun/core/structured/schema_test.py @@ -157,7 +157,8 @@ def test_value_repr(self): def test_parse(self): schema = schema_lib.Schema(int) self.assertEqual(schema.parse('{"result": 1}'), 1) - with self.assertRaisesRegex(TypeError, 'Expect .* but encountered .*'): + with self.assertRaisesRegex( + schema_lib.SchemaError, 'Expect .* but encountered .*'): schema.parse('{"result": "def"}') with self.assertRaisesRegex(ValueError, 'Unsupported protocol'): @@ -576,7 +577,13 @@ def test_parse_basics(self): Activity('play'), ) with self.assertRaisesRegex( - ValueError, 'The root node of the JSON must be a dict with key `result`' + schema_lib.JsonError, 'JSONDecodeError' + ): + schema_lib.ValueJsonRepr().parse('{"abc", 1}') + + with self.assertRaisesRegex( + schema_lib.JsonError, + 'The root node of the JSON must be a dict with key `result`' ): schema_lib.ValueJsonRepr().parse('{"abc": 1}') @@ -596,12 +603,13 @@ def test_parse_with_new_lines(self): def test_parse_with_malformated_json(self): with self.assertRaisesRegex( - ValueError, 'No JSON dict in the output' + schema_lib.JsonError, 'No JSON dict in the output' ): schema_lib.ValueJsonRepr().parse('The answer is 1.') with self.assertRaisesRegex( - ValueError, 'Malformated JSON: missing .* closing curly braces' + schema_lib.JsonError, + 'Malformated JSON: missing .* closing curly braces' ): schema_lib.ValueJsonRepr().parse('{"result": 1')