Skip to content

Commit

Permalink
Improve lf.function_gen().
Browse files Browse the repository at this point in the history
1. Improve the default behavior: num_retries=1 and no unittest.
2. Improve the meaning of unittest field. User should set to 'auto' to explicitely enable auto unittest. By default, the None value means no unittest.

PiperOrigin-RevId: 703607882
  • Loading branch information
yifenglou authored and langfun authors committed Dec 6, 2024
1 parent 341eedc commit 421b8d5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 16 deletions.
49 changes: 35 additions & 14 deletions langfun/core/structured/function_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import functools
import inspect
import re
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Literal, Optional, Tuple

from langfun.core import language_model
from langfun.core import template
Expand All @@ -25,7 +25,7 @@
import pyglove as pg


def unittest_gen(signature, lm, num_retries=10):
def unittest_gen(signature, lm, num_retries=1):
"""Generates unit tests for a python function signature."""

class UnitTest(pg.Object):
Expand Down Expand Up @@ -78,10 +78,13 @@ def _function_gen(
func: Callable[..., Any],
signature: str,
lm: language_model.LanguageModel,
num_retries: int = 10,
num_retries: int = 1,
unittest: Optional[
Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
Callable[[Callable[..., Any]], None]
| list[Tuple[Any, Any]]
| Literal["auto"]
] = None,
unittest_num_retries: int = 1,
):
"""Generates a python function with LLM and verify its quality with unit testing."""

Expand Down Expand Up @@ -131,9 +134,11 @@ def calculate_area_circle(radius: float) -> float:
"""

unittest_examples = None
if unittest is None:
unittest_examples = unittest_gen(signature, lm=lm)
elif not callable(unittest):
if unittest == "auto":
unittest_examples = unittest_gen(
signature, lm=lm, num_retries=unittest_num_retries
)
elif isinstance(unittest, list):
unittest_examples = unittest

for _ in range(num_retries):
Expand All @@ -145,11 +150,16 @@ def calculate_area_circle(radius: float) -> float:

# Check whether the sigantures are the same.
if inspect.signature(f) != inspect.signature(func):
pg.logging.warning(
"Signature mismatch. Expected: %s, Actual: %s",
inspect.signature(func),
inspect.signature(f),
)
continue

if callable(unittest):
unittest(f)
else:
elif unittest_examples:
unittest_with_test_cases(f, unittest_examples)

return f, source_code
Expand All @@ -172,10 +182,13 @@ def _process_signature(signature):
def function_gen(
lm: language_model.LanguageModel,
cache_filename: str | None = None,
num_retries: int = 10,
num_retries: int = 1,
unittest: Optional[
Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
Callable[[Callable[..., Any]], None]
| list[Tuple[Any, Any]]
| Literal["auto"]
] = None,
unittest_num_retries: int = 1,
):
"""A decorator for automating function generation using a language model.
Expand All @@ -192,9 +205,12 @@ def function_gen(
make to generate a suitable function implementation.
unittest: This optional parameter enables the definition of custom unit
tests. You can either provide a list of test cases as tuples of inputs
and outputs, or a function that throws an error if a test fails. If left
as None (the default setting), the LLM will automatically create the
unit test cases.
and outputs, or a function that throws an error if a test fails, or let
LLM automatically create the unit test cases. If a generated function is
and returned, it should pass all the unittests.
unittest_num_retries: If unittest is set to "auto", this parameter
specifies the number of times the LLM's attempts to generate unit test
cases.
Returns:
The implemented function object.
Expand Down Expand Up @@ -226,7 +242,12 @@ def lm_generated_func(*args, **kwargs):
return func.__function__(*args, **kwargs)

func.__function__, func.__source_code__ = _function_gen(
func, signature, lm, num_retries=num_retries, unittest=unittest
func,
signature,
lm,
num_retries=num_retries,
unittest=unittest,
unittest_num_retries=unittest_num_retries,
)
if func.__function__ is None:
raise ValueError(f"Function generation failed. Signature:\n{signature}")
Expand Down
44 changes: 42 additions & 2 deletions langfun/core/structured/function_generation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,42 @@ def linear_search(items, target):

lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])

@function_generation.function_gen(lm=lm, unittest='auto')
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.
Args:
items (list): The list to search within.
target: The value to search for.
Returns:
int: The index of the target value if found, otherwise -1.
"""

self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
self.assertEqual(linear_search.source(), function_gen_lm_response)

def test_generate_function_without_unittest(self):
function_gen_lm_response = inspect.cleandoc("""
def linear_search(items, target):
\"\"\"
Performs a linear search on a list to find a target value.
Args:
items (list): The list to search within.
target: The value to search for.
Returns:
int: The index of the target value if found, otherwise -1.
\"\"\"
for i, item in enumerate(items):
if item == target:
return i
return -1
""")

lm = fake.StaticSequence([function_gen_lm_response])

@function_generation.function_gen(lm=lm)
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.
Expand Down Expand Up @@ -258,7 +294,9 @@ def _unittest_fn(func):
cache_file = os.path.join(cache_file_dir, 'cache_file.json')

@function_generation.function_gen(
lm=lm, unittest=_unittest_fn, cache_filename=cache_file
lm=lm,
unittest=_unittest_fn,
cache_filename=cache_file,
)
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.
Expand Down Expand Up @@ -310,7 +348,9 @@ def _unittest_fn(func):

custom_unittest = _unittest_fn

@function_generation.function_gen(lm=lm, unittest=custom_unittest)
@function_generation.function_gen(
lm=lm, unittest=custom_unittest, num_retries=2
)
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.
Expand Down

0 comments on commit 421b8d5

Please sign in to comment.