From 421b8d5b302369bc5f75228a797f1e4123810a91 Mon Sep 17 00:00:00 2001 From: Yifeng Lu Date: Fri, 6 Dec 2024 13:58:53 -0800 Subject: [PATCH] Improve lf.function_gen(). 1. Improve the default behavior: num_retries=1 and no unittest. 2. Improve the meaning of unittest field. Users should set it to 'auto' to explicitly enable auto unittest. By default, the None value means no unittest. PiperOrigin-RevId: 703607882 --- .../core/structured/function_generation.py | 49 +++++++++++++------ .../structured/function_generation_test.py | 44 ++++++++++++++++- 2 files changed, 77 insertions(+), 16 deletions(-) diff --git a/langfun/core/structured/function_generation.py b/langfun/core/structured/function_generation.py index d00ca58..173081d 100644 --- a/langfun/core/structured/function_generation.py +++ b/langfun/core/structured/function_generation.py @@ -16,7 +16,7 @@ import functools import inspect import re -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Literal, Optional, Tuple from langfun.core import language_model from langfun.core import template @@ -25,7 +25,7 @@ import pyglove as pg -def unittest_gen(signature, lm, num_retries=10): +def unittest_gen(signature, lm, num_retries=1): """Generates unit tests for a python function signature.""" class UnitTest(pg.Object): @@ -78,10 +78,13 @@ def _function_gen( func: Callable[..., Any], signature: str, lm: language_model.LanguageModel, - num_retries: int = 10, + num_retries: int = 1, unittest: Optional[ - Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]] + Callable[[Callable[..., Any]], None] + | list[Tuple[Any, Any]] + | Literal["auto"] ] = None, + unittest_num_retries: int = 1, ): """Generates a python function with LLM and verify its quality with unit testing.""" @@ -131,9 +134,11 @@ def calculate_area_circle(radius: float) -> float: """ unittest_examples = None - if unittest is None: - unittest_examples = unittest_gen(signature, lm=lm) - elif not callable(unittest): + if
unittest == "auto": + unittest_examples = unittest_gen( + signature, lm=lm, num_retries=unittest_num_retries + ) + elif isinstance(unittest, list): unittest_examples = unittest for _ in range(num_retries): @@ -145,11 +150,16 @@ def calculate_area_circle(radius: float) -> float: # Check whether the sigantures are the same. if inspect.signature(f) != inspect.signature(func): + pg.logging.warning( + "Signature mismatch. Expected: %s, Actual: %s", + inspect.signature(func), + inspect.signature(f), + ) continue if callable(unittest): unittest(f) - else: + elif unittest_examples: unittest_with_test_cases(f, unittest_examples) return f, source_code @@ -172,10 +182,13 @@ def _process_signature(signature): def function_gen( lm: language_model.LanguageModel, cache_filename: str | None = None, - num_retries: int = 10, + num_retries: int = 1, unittest: Optional[ - Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]] + Callable[[Callable[..., Any]], None] + | list[Tuple[Any, Any]] + | Literal["auto"] ] = None, + unittest_num_retries: int = 1, ): """A decorator for automating function generation using a language model. @@ -192,9 +205,12 @@ def function_gen( make to generate a suitable function implementation. unittest: This optional parameter enables the definition of custom unit tests. You can either provide a list of test cases as tuples of inputs - and outputs, or a function that throws an error if a test fails. If left - as None (the default setting), the LLM will automatically create the - unit test cases. + and outputs, or a function that throws an error if a test fails, or let + LLM automatically create the unit test cases. If a generated function is + and returned, it should pass all the unittests. + unittest_num_retries: If unittest is set to "auto", this parameter + specifies the number of times the LLM's attempts to generate unit test + cases. Returns: The implemented function object. 
@@ -226,7 +242,12 @@ def lm_generated_func(*args, **kwargs): return func.__function__(*args, **kwargs) func.__function__, func.__source_code__ = _function_gen( - func, signature, lm, num_retries=num_retries, unittest=unittest + func, + signature, + lm, + num_retries=num_retries, + unittest=unittest, + unittest_num_retries=unittest_num_retries, ) if func.__function__ is None: raise ValueError(f"Function generation failed. Signature:\n{signature}") diff --git a/langfun/core/structured/function_generation_test.py b/langfun/core/structured/function_generation_test.py index 22927c9..7cf4c6a 100644 --- a/langfun/core/structured/function_generation_test.py +++ b/langfun/core/structured/function_generation_test.py @@ -63,6 +63,42 @@ def linear_search(items, target): lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response]) + @function_generation.function_gen(lm=lm, unittest='auto') + def linear_search(items, target): # pylint: disable=unused-argument + """Performs a linear search on a list to find a target value. + + Args: + items (list): The list to search within. + target: The value to search for. + + Returns: + int: The index of the target value if found, otherwise -1. + """ + + self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2) + self.assertEqual(linear_search.source(), function_gen_lm_response) + + def test_generate_function_without_unittest(self): + function_gen_lm_response = inspect.cleandoc(""" + def linear_search(items, target): + \"\"\" + Performs a linear search on a list to find a target value. + + Args: + items (list): The list to search within. + target: The value to search for. + + Returns: + int: The index of the target value if found, otherwise -1. 
+ \"\"\" + for i, item in enumerate(items): + if item == target: + return i + return -1 + """) + + lm = fake.StaticSequence([function_gen_lm_response]) + @function_generation.function_gen(lm=lm) def linear_search(items, target): # pylint: disable=unused-argument """Performs a linear search on a list to find a target value. @@ -258,7 +294,9 @@ def _unittest_fn(func): cache_file = os.path.join(cache_file_dir, 'cache_file.json') @function_generation.function_gen( - lm=lm, unittest=_unittest_fn, cache_filename=cache_file + lm=lm, + unittest=_unittest_fn, + cache_filename=cache_file, ) def linear_search(items, target): # pylint: disable=unused-argument """Performs a linear search on a list to find a target value. @@ -310,7 +348,9 @@ def _unittest_fn(func): custom_unittest = _unittest_fn - @function_generation.function_gen(lm=lm, unittest=custom_unittest) + @function_generation.function_gen( + lm=lm, unittest=custom_unittest, num_retries=2 + ) def linear_search(items, target): # pylint: disable=unused-argument """Performs a linear search on a list to find a target value.