diff --git a/langfun/core/coding/python.py b/langfun/core/coding/python.py index 2db1474..c23d652 100644 --- a/langfun/core/coding/python.py +++ b/langfun/core/coding/python.py @@ -18,6 +18,7 @@ import enum import inspect import io +import re import textwrap from typing import Annotated, Any @@ -95,6 +96,8 @@ def __str__(self): class PythonCodeParser(lf.Component): """Python code parser with permission control.""" + _ID_REGEX = re.compile('^[a-zA-Z_\\-]*$') + class _CodeValidator(ast.NodeVisitor): """Python AST node visitor for ensuring code are permitted.""" @@ -192,9 +195,9 @@ def clean(self, code_text: str) -> str: while i < len(code_text): c = code_text[i] # Detect code block separator (```). - if (quote_char is None + if (not in_comment + and quote_char is None and c == '`' - and i < len(code_text) - 2 and code_text[i:i + 3] == '```'): in_code = not in_code if in_code: @@ -203,29 +206,52 @@ def clean(self, code_text: str) -> str: else: break - if in_code: - code.write(c) - # Detect string literal boundary. if (not in_comment and c in ('\'', '"') and i > 0 and code_text[i - 1] != '\\'): + + # Handle ''' and """. + if code_text[i: i + 3] == c * 3: + c = c * 3 + i += 2 + if quote_char is None: quote_char = c elif quote_char == c: - quote_char = None + # NOTE(daiyip): at times, LM forgets to escape quotes inside a string. + # Thus we do some smart checking here to automatically correct such + # case. + if i < len(code_text) - 1 and code_text[i + 1] not in '.,]}) \t\n+*': + c = f'\\{c}' + else: + quote_char = None # Detect comment. elif c == '#' and quote_char is None: in_comment = True # Detect end-of-comment. elif c == '\n': - in_comment = False + # NOTE(daiyip): deal with cases that LM forgot to escape linebreaks + # within strings. + if quote_char is not None: + # Only add \\ for ' and " (other than ''' and """). + if len(quote_char) == 1: + c = '\\n' + else: + in_comment = False + + if in_code: + code.write(c) + i += 1 code = code.getvalue() if code: - code = code.lstrip('python') + pos = code.find('\n') + # Strip markdown code type. E.g. ```python + if pos > 0 and self._ID_REGEX.match(code[:pos]): + code = code[pos:] else: # Maybe-code that resides not within a code markdown block. code = code_text diff --git a/langfun/core/coding/python_test.py b/langfun/core/coding/python_test.py index 5d1c0f1..6f04bf0 100644 --- a/langfun/core/coding/python_test.py +++ b/langfun/core/coding/python_test.py @@ -122,13 +122,13 @@ def test_clean(self): The code looks as below: ```python - x = y + 1 - z = x * y + x = y + 1 # ``` + z = x * y # " ``` """, """ - x = y + 1 - z = x * y + x = y + 1 # ``` + z = x * y # " """ ) self.assert_clean( @@ -184,6 +184,71 @@ class A: y: str """ ) + self.assert_clean( + """ + ```tool-code + class A: + x: int + y: str + ``` + """, + """ + class A: + x: int + y: str + """ + ) + self.assert_clean( + """ + ``` + class A: + '''Class a. + + Examples: + ``` + A(1, 2) + ``` + ''' + x: int + y: str + ``` + """, + """ + class A: + '''Class a. + + Examples: + ``` + A(1, 2) + ``` + ''' + x: int + y: str + """ + ) + + def test_clean_with_auto_correction(self): + self.assert_clean( + """ + ```python + x = 'John's home' + ``` + """, + """ + x = 'John\\'s home' + """ + ) + self.assert_clean( + """ + ```python + x = 'Hello + World' + ``` + """, + """ + x = 'Hello\\n World' + """ + ) def assert_allowed(self, code: str, permission: python.CodePermission): _, ast = python.PythonCodeParser().parse(code, permission)