Skip to content

Commit

Permalink
Disable autofix as default for all Langfun APIs.
Browse files Browse the repository at this point in the history
I think we need to carefully revisit the effectiveness of autofix before enabling it as default.

PiperOrigin-RevId: 598459255
  • Loading branch information
yifenglou authored and langfun authors committed Jan 15, 2024
1 parent 64b027b commit faa4103
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 10 deletions.
2 changes: 1 addition & 1 deletion langfun/core/structured/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def complete(
*,
lm: lf.LanguageModel | None = None,
examples: list[mapping.MappingExample] | None = None,
autofix: int = 3,
autofix: int = 0,
autofix_lm: lf.LanguageModel | None = None,
returns_message: bool = False,
**kwargs,
Expand Down
4 changes: 3 additions & 1 deletion langfun/core/structured/completion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ class Solution(pg.Object):
debug=True,
)
self.assertEqual(
completion.complete(Solution.partial('Compute 1 + 1'), lm=lm),
completion.complete(
Solution.partial('Compute 1 + 1'), lm=lm, autofix=3
),
Solution(question='Compute 1 + 1', answer=2),
)

Expand Down
2 changes: 1 addition & 1 deletion langfun/core/structured/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def mapping_request(self) -> MappingExample:
'If 0 (default), there is no automatic correction. '
'This flag is effective only when the output needs to be structured.'
),
] = 3
] = 0

autofix_lm: Annotated[
lf.LanguageModel,
Expand Down
4 changes: 2 additions & 2 deletions langfun/core/structured/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def parse(
user_prompt: str | None = None,
lm: lf.LanguageModel | None = None,
examples: list[mapping.MappingExample] | None = None,
autofix: int = 3,
autofix: int = 0,
autofix_lm: lf.LanguageModel | None = None,
protocol: schema_lib.SchemaProtocol = 'python',
returns_message: bool = False,
Expand Down Expand Up @@ -184,7 +184,7 @@ def call(
lm: lf.LanguageModel | None = None,
parsing_lm: lf.LanguageModel | None = None,
parsing_examples: list[mapping.MappingExample] | None = None,
autofix: int = 3,
autofix: int = 0,
autofix_lm: lf.LanguageModel | None = None,
response_postprocess: Callable[[str], str] | None = None,
protocol: schema_lib.SchemaProtocol = 'python',
Expand Down
6 changes: 4 additions & 2 deletions langfun/core/structured/parsing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_autofix(self):
)
"""),
])
self.assertEqual(parsing.parse('three', int, lm=lm), 3)
self.assertEqual(parsing.parse('three', int, lm=lm, autofix=3), 3)

def test_parse(self):
lm = fake.StaticResponse('1')
Expand Down Expand Up @@ -677,7 +677,9 @@ def test_call_with_autofix(self):
],
debug=True,
)
self.assertEqual(parsing.call('what is one plus two?', int, lm=lm), 3)
self.assertEqual(
parsing.call('what is one plus two?', int, lm=lm, autofix=3), 3
)

def test_call_with_structured_input(self):
self.assertEqual(parsing.call(1, lm=fake.StaticResponse('2')), '2')
Expand Down
2 changes: 1 addition & 1 deletion langfun/core/structured/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def query(
*,
lm: lf.LanguageModel | None = None,
examples: list[mapping.MappingExample] | None = None,
autofix: int = 3,
autofix: int = 0,
autofix_lm: lf.LanguageModel | None = None,
protocol: schema_lib.SchemaProtocol = 'python',
returns_message: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions langfun/core/structured/prompting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def test_bad_response(self):
coding.CodeError,
'name .* is not defined',
):
prompting.query('Compute 1 + 2', int, autofix=0)
prompting.query('Compute 1 + 2', int)

def test_autofix(self):
lm = fake.StaticSequence([
Expand All @@ -432,7 +432,7 @@ def test_autofix(self):
)
"""),
])
self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm), 1)
self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)


class QueryStructureJsonTest(unittest.TestCase):
Expand Down

0 comments on commit faa4103

Please sign in to comment.