Skip to content

Commit

Permalink
lf.function_gen: evaluating LM generated code with the symbols from…
Browse files Browse the repository at this point in the history
… the original function module.

PiperOrigin-RevId: 703704741
  • Loading branch information
daiyip authored and langfun authors committed Dec 7, 2024
1 parent 421b8d5 commit 074e47a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
41 changes: 27 additions & 14 deletions langfun/core/structured/function_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def unittest_with_test_cases(f, unittests):

def _function_gen(
func: Callable[..., Any],
context: dict[str, Any],
signature: str,
lm: language_model.LanguageModel,
num_retries: int = 1,
Expand Down Expand Up @@ -141,32 +142,37 @@ def calculate_area_circle(radius: float) -> float:
elif isinstance(unittest, list):
unittest_examples = unittest

last_error = None
for _ in range(num_retries):
try:
source_code = prompting.query(
PythonFunctionPrompt(signature=signature), lm=lm
)
f = python.evaluate(source_code)
f = python.evaluate(source_code, global_vars=context)

# Check whether the sigantures are the same.
if inspect.signature(f) != inspect.signature(func):
pg.logging.warning(
"Signature mismatch. Expected: %s, Actual: %s",
inspect.signature(func),
inspect.signature(f),
raise python.CodeError(
code=source_code,
cause=TypeError(
f"Signature mismatch: Expected: {inspect.signature(func)}, "
f"Actual: {inspect.signature(f)}.",
),
)
continue

if callable(unittest):
unittest(f)
elif unittest_examples:
unittest_with_test_cases(f, unittest_examples)

return f, source_code
except Exception: # pylint: disable=broad-exception-caught
pass

return None, None
except python.CodeError as e:
last_error = e
pg.logging.warning(
f"Bad code generated: {e}",
)
print(last_error)
raise last_error


def _process_signature(signature):
Expand Down Expand Up @@ -220,6 +226,13 @@ def _decorate(func):
setattr(func, "__function__", None)
setattr(func, "__source_code__", None)

# Prepare the globals/locals for the generated code to be evaluated against.
callstack = inspect.stack()
assert len(callstack) > 1
context = dict(callstack[1][0].f_globals)
context.update(callstack[1][0].f_locals)
context.pop(func.__name__, None)

@functools.wraps(func)
def lm_generated_func(*args, **kwargs):
if func.__function__ is not None:
Expand All @@ -238,20 +251,20 @@ def lm_generated_func(*args, **kwargs):

if signature in cache:
func.__source_code__ = cache[signature]
func.__function__ = python.evaluate(func.__source_code__)
func.__function__ = python.evaluate(
func.__source_code__, global_vars=context
)
return func.__function__(*args, **kwargs)

func.__function__, func.__source_code__ = _function_gen(
func,
context,
signature,
lm,
num_retries=num_retries,
unittest=unittest,
unittest_num_retries=unittest_num_retries,
)
if func.__function__ is None:
raise ValueError(f"Function generation failed. Signature:\n{signature}")

if cache_filename is not None:
cache[signature] = func.__source_code__
cache.save(cache_filename)
Expand Down
30 changes: 30 additions & 0 deletions langfun/core/structured/function_generation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,36 @@ def linear_search(items, target): # pylint: disable=unused-argument

self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)

def test_context_passthrough(self):

class Number(pg.Object):
value: int

function_gen_lm_response = inspect.cleandoc("""
```python
def add(a: Number, b: Number) -> Number:
\"\"\"Adds two numbers together.\"\"\"
return Number(a.value + b.value)
```
""")

lm = fake.StaticSequence(
[function_gen_lm_response]
)

def _unittest_fn(func):
assert func(Number(1), Number(2)) == Number(3)

custom_unittest = _unittest_fn

@function_generation.function_gen(
lm=lm, unittest=custom_unittest, num_retries=1
)
def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
"""Adds two numbers together."""

self.assertEqual(add(Number(2), Number(3)), Number(5))

def test_siganture_check(self):
incorrect_signature_lm_response = inspect.cleandoc("""
```python
Expand Down

0 comments on commit 074e47a

Please sign in to comment.