Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve lf.function_gen(). #356

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions langfun/core/structured/function_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import functools
import inspect
import re
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Literal, Optional, Tuple

from langfun.core import language_model
from langfun.core import template
Expand All @@ -25,7 +25,7 @@
import pyglove as pg


def unittest_gen(signature, lm, num_retries=10):
def unittest_gen(signature, lm, num_retries=1):
"""Generates unit tests for a python function signature."""

class UnitTest(pg.Object):
Expand Down Expand Up @@ -78,10 +78,13 @@ def _function_gen(
func: Callable[..., Any],
signature: str,
lm: language_model.LanguageModel,
num_retries: int = 10,
num_retries: int = 1,
unittest: Optional[
Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
Callable[[Callable[..., Any]], None]
| list[Tuple[Any, Any]]
| Literal["auto"]
] = None,
unittest_num_retries: int = 1,
):
"""Generates a python function with LLM and verify its quality with unit testing."""

Expand Down Expand Up @@ -131,9 +134,11 @@ def calculate_area_circle(radius: float) -> float:
"""

unittest_examples = None
if unittest is None:
unittest_examples = unittest_gen(signature, lm=lm)
elif not callable(unittest):
if unittest == "auto":
unittest_examples = unittest_gen(
signature, lm=lm, num_retries=unittest_num_retries
)
elif isinstance(unittest, list):
unittest_examples = unittest

for _ in range(num_retries):
Expand All @@ -145,11 +150,16 @@ def calculate_area_circle(radius: float) -> float:

# Check whether the sigantures are the same.
if inspect.signature(f) != inspect.signature(func):
pg.logging.warning(
"Signature mismatch. Expected: %s, Actual: %s",
inspect.signature(func),
inspect.signature(f),
)
continue

if callable(unittest):
unittest(f)
else:
elif unittest_examples:
unittest_with_test_cases(f, unittest_examples)

return f, source_code
Expand All @@ -172,10 +182,13 @@ def _process_signature(signature):
def function_gen(
lm: language_model.LanguageModel,
cache_filename: str | None = None,
num_retries: int = 10,
num_retries: int = 1,
unittest: Optional[
Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
Callable[[Callable[..., Any]], None]
| list[Tuple[Any, Any]]
| Literal["auto"]
] = None,
unittest_num_retries: int = 1,
):
"""A decorator for automating function generation using a language model.

Expand All @@ -192,9 +205,12 @@ def function_gen(
make to generate a suitable function implementation.
unittest: This optional parameter enables the definition of custom unit
tests. You can either provide a list of test cases as tuples of inputs
and outputs, or a function that throws an error if a test fails. If left
as None (the default setting), the LLM will automatically create the
unit test cases.
and outputs, or a function that throws an error if a test fails, or let
LLM automatically create the unit test cases. If a generated function is
and returned, it should pass all the unittests.
unittest_num_retries: If unittest is set to "auto", this parameter
specifies the number of times the LLM's attempts to generate unit test
cases.

Returns:
The implemented function object.
Expand Down Expand Up @@ -226,7 +242,12 @@ def lm_generated_func(*args, **kwargs):
return func.__function__(*args, **kwargs)

func.__function__, func.__source_code__ = _function_gen(
func, signature, lm, num_retries=num_retries, unittest=unittest
func,
signature,
lm,
num_retries=num_retries,
unittest=unittest,
unittest_num_retries=unittest_num_retries,
)
if func.__function__ is None:
raise ValueError(f"Function generation failed. Signature:\n{signature}")
Expand Down
44 changes: 42 additions & 2 deletions langfun/core/structured/function_generation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,42 @@ def linear_search(items, target):

lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])

@function_generation.function_gen(lm=lm, unittest='auto')
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.

Args:
items (list): The list to search within.
target: The value to search for.

Returns:
int: The index of the target value if found, otherwise -1.
"""

self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
self.assertEqual(linear_search.source(), function_gen_lm_response)

def test_generate_function_without_unittest(self):
function_gen_lm_response = inspect.cleandoc("""
def linear_search(items, target):
\"\"\"
Performs a linear search on a list to find a target value.

Args:
items (list): The list to search within.
target: The value to search for.

Returns:
int: The index of the target value if found, otherwise -1.
\"\"\"
for i, item in enumerate(items):
if item == target:
return i
return -1
""")

lm = fake.StaticSequence([function_gen_lm_response])

@function_generation.function_gen(lm=lm)
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.
Expand Down Expand Up @@ -258,7 +294,9 @@ def _unittest_fn(func):
cache_file = os.path.join(cache_file_dir, 'cache_file.json')

@function_generation.function_gen(
lm=lm, unittest=_unittest_fn, cache_filename=cache_file
lm=lm,
unittest=_unittest_fn,
cache_filename=cache_file,
)
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.
Expand Down Expand Up @@ -310,7 +348,9 @@ def _unittest_fn(func):

custom_unittest = _unittest_fn

@function_generation.function_gen(lm=lm, unittest=custom_unittest)
@function_generation.function_gen(
lm=lm, unittest=custom_unittest, num_retries=2
)
def linear_search(items, target): # pylint: disable=unused-argument
"""Performs a linear search on a list to find a target value.

Expand Down
Loading