Skip to content

Commit

Permalink
Automatically correct bad Python code when LMs forget to escape quotes and newlines.
Browse files Browse the repository at this point in the history

This is one of the major failure patterns we encountered with symbolic prompting.

PiperOrigin-RevId: 571567898
  • Loading branch information
daiyip authored and langfun authors committed Oct 7, 2023
1 parent 2fbb846 commit 1ed40e7
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 12 deletions.
42 changes: 34 additions & 8 deletions langfun/core/coding/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import enum
import inspect
import io
import re
import textwrap
from typing import Annotated, Any

Expand Down Expand Up @@ -95,6 +96,8 @@ def __str__(self):
class PythonCodeParser(lf.Component):
"""Python code parser with permission control."""

_ID_REGEX = re.compile('^[a-zA-Z_\\-]*$')

class _CodeValidator(ast.NodeVisitor):
"""Python AST node visitor for ensuring code are permitted."""

Expand Down Expand Up @@ -192,9 +195,9 @@ def clean(self, code_text: str) -> str:
while i < len(code_text):
c = code_text[i]
# Detect code block separator (```).
if (quote_char is None
if (not in_comment
and quote_char is None
and c == '`'
and i < len(code_text) - 2
and code_text[i:i + 3] == '```'):
in_code = not in_code
if in_code:
Expand All @@ -203,29 +206,52 @@ def clean(self, code_text: str) -> str:
else:
break

if in_code:
code.write(c)

# Detect string literal boundary.
if (not in_comment
and c in ('\'', '"')
and i > 0
and code_text[i - 1] != '\\'):

# Handle ''' and """.
if code_text[i: i + 3] == c * 3:
c = c * 3
i += 2

if quote_char is None:
quote_char = c
elif quote_char == c:
quote_char = None
# NOTE(daiyip): at times, LM forgets to escape quotes inside a string.
# Thus we do some smart checking here to automatically correct such
# case.
if i < len(code_text) - 1 and code_text[i + 1] not in '.,]}) \t\n+*':
c = f'\\{c}'
else:
quote_char = None
# Detect comment.
elif c == '#' and quote_char is None:
in_comment = True
# Detect end-of-comment.
elif c == '\n':
in_comment = False
# NOTE(daiyip): deal with cases that LM forgot to escape linebreaks
# within strings.
if quote_char is not None:
# Only add \\ for ' and " (other than ''' and """).
if len(quote_char) == 1:
c = '\\n'
else:
in_comment = False

if in_code:
code.write(c)

i += 1

code = code.getvalue()
if code:
code = code.lstrip('python')
pos = code.find('\n')
# Strip markdown code type. E.g. ```python
if pos > 0 and self._ID_REGEX.match(code[:pos]):
code = code[pos:]
else:
# Maybe-code that resides not within a code markdown block.
code = code_text
Expand Down
73 changes: 69 additions & 4 deletions langfun/core/coding/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ def test_clean(self):
The code looks as below:
```python
x = y + 1
z = x * y
x = y + 1 # ```
z = x * y # "
```
""",
"""
x = y + 1
z = x * y
x = y + 1 # ```
z = x * y # "
"""
)
self.assert_clean(
Expand Down Expand Up @@ -184,6 +184,71 @@ class A:
y: str
"""
)
self.assert_clean(
"""
```tool-code
class A:
x: int
y: str
```
""",
"""
class A:
x: int
y: str
"""
)
self.assert_clean(
"""
```
class A:
'''Class a.
Examples:
```
A(1, 2)
```
'''
x: int
y: str
```
""",
"""
class A:
'''Class a.
Examples:
```
A(1, 2)
```
'''
x: int
y: str
"""
)

def test_clean_with_auto_correction(self):
self.assert_clean(
"""
```python
x = 'John's home'
```
""",
"""
x = 'John\\'s home'
"""
)
self.assert_clean(
"""
```python
x = 'Hello
World'
```
""",
"""
x = 'Hello\\n World'
"""
)

def assert_allowed(self, code: str, permission: python.CodePermission):
_, ast = python.PythonCodeParser().parse(code, permission)
Expand Down

0 comments on commit 1ed40e7

Please sign in to comment.