From bcf76975b8eb28e296268a93ba8471759f3fffa2 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Sat, 7 Dec 2024 09:49:17 -0800 Subject: [PATCH] `lf.function_gen`: evaluating LM generated code with the symbols from the original function module. PiperOrigin-RevId: 703823251 --- .../core/structured/function_generation.py | 40 ++++++++++++------- .../structured/function_generation_test.py | 30 ++++++++++++++ 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/langfun/core/structured/function_generation.py b/langfun/core/structured/function_generation.py index 173081d..ef54380 100644 --- a/langfun/core/structured/function_generation.py +++ b/langfun/core/structured/function_generation.py @@ -76,6 +76,7 @@ def unittest_with_test_cases(f, unittests): def _function_gen( func: Callable[..., Any], + context: dict[str, Any], signature: str, lm: language_model.LanguageModel, num_retries: int = 1, @@ -141,21 +142,23 @@ def calculate_area_circle(radius: float) -> float: elif isinstance(unittest, list): unittest_examples = unittest + last_error = None for _ in range(num_retries): try: source_code = prompting.query( PythonFunctionPrompt(signature=signature), lm=lm ) - f = python.evaluate(source_code) + f = python.evaluate(source_code, global_vars=context) # Check whether the sigantures are the same. if inspect.signature(f) != inspect.signature(func): - pg.logging.warning( - "Signature mismatch. Expected: %s, Actual: %s", - inspect.signature(func), - inspect.signature(f), + raise python.CodeError( + code=source_code, + cause=TypeError( + f"Signature mismatch: Expected: {inspect.signature(func)}, " + f"Actual: {inspect.signature(f)}.", + ), ) - continue if callable(unittest): unittest(f) @@ -163,10 +166,12 @@ def calculate_area_circle(radius: float) -> float: unittest_with_test_cases(f, unittest_examples) return f, source_code - except Exception: # pylint: disable=broad-exception-caught - pass - - return None, None + except python.CodeError as e: + last_error = e + pg.logging.warning( + f"Bad code generated: {e}", + ) + raise last_error def _process_signature(signature): @@ -220,6 +225,13 @@ def _decorate(func): setattr(func, "__function__", None) setattr(func, "__source_code__", None) + # Prepare the globals/locals for the generated code to be evaluated against. + callstack = inspect.stack() + assert len(callstack) > 1 + context = dict(callstack[1][0].f_globals) + context.update(callstack[1][0].f_locals) + context.pop(func.__name__, None) + @functools.wraps(func) def lm_generated_func(*args, **kwargs): if func.__function__ is not None: @@ -238,20 +250,20 @@ def lm_generated_func(*args, **kwargs): if signature in cache: func.__source_code__ = cache[signature] - func.__function__ = python.evaluate(func.__source_code__) + func.__function__ = python.evaluate( + func.__source_code__, global_vars=context + ) return func.__function__(*args, **kwargs) func.__function__, func.__source_code__ = _function_gen( func, + context, signature, lm, num_retries=num_retries, unittest=unittest, unittest_num_retries=unittest_num_retries, ) - if func.__function__ is None: - raise ValueError(f"Function generation failed. Signature:\n{signature}") - if cache_filename is not None: cache[signature] = func.__source_code__ cache.save(cache_filename) diff --git a/langfun/core/structured/function_generation_test.py b/langfun/core/structured/function_generation_test.py index 7cf4c6a..f31d003 100644 --- a/langfun/core/structured/function_generation_test.py +++ b/langfun/core/structured/function_generation_test.py @@ -311,6 +311,36 @@ def linear_search(items, target): # pylint: disable=unused-argument self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2) + def test_context_passthrough(self): + + class Number(pg.Object): + value: int + + function_gen_lm_response = inspect.cleandoc(""" + ```python + def add(a: Number, b: Number) -> Number: + \"\"\"Adds two numbers together.\"\"\" + return Number(a.value + b.value) + ``` + """) + + lm = fake.StaticSequence( + [function_gen_lm_response] + ) + + def _unittest_fn(func): + assert func(Number(1), Number(2)) == Number(3) + + custom_unittest = _unittest_fn + + @function_generation.function_gen( + lm=lm, unittest=custom_unittest, num_retries=1 + ) + def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument + """Adds two numbers together.""" + + self.assertEqual(add(Number(2), Number(3)), Number(5)) + def test_siganture_check(self): incorrect_signature_lm_response = inspect.cleandoc(""" ```python