Skip to content

Commit

Permalink
Chain parsing lm output with original lm output even when MappingErro…
Browse files Browse the repository at this point in the history
…r is encountered during parsing.

PiperOrigin-RevId: 707628509
  • Loading branch information
daiyip authored and langfun authors committed Dec 18, 2024
1 parent e0eab72 commit 66331b5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
41 changes: 24 additions & 17 deletions langfun/core/structured/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,24 +270,31 @@ def call(
if schema in (str, None):
return lm_output if returns_message else lm_output.text

def _chain_nl_output_message(parsing_message: lf.Message):
"""Chain the source of the parsed output to the LM output."""
parsing_message.root.source = lm_output
parsing_message.tag('parsing-lm-output')
parsing_message.lm_input.tag('parsing-lm-input')

# Call `parsing_lm` for structured parsing.
parsing_message = querying.query(
lm_output.text,
schema,
examples=parsing_examples,
lm=parsing_lm or lm,
include_context=parsing_include_context,
cache_seed=cache_seed,
autofix=autofix,
autofix_lm=autofix_lm or lm,
protocol=protocol,
returns_message=True,
**kwargs,
)
# Chain the source of the parsed output to the LM output.
parsing_message.root.source = lm_output
parsing_message.tag('parsing-lm-output')
parsing_message.lm_input.tag('parsing-lm-input')
try:
parsing_message = querying.query(
lm_output.text,
schema,
examples=parsing_examples,
lm=parsing_lm or lm,
include_context=parsing_include_context,
cache_seed=cache_seed,
autofix=autofix,
autofix_lm=autofix_lm or lm,
protocol=protocol,
returns_message=True,
**kwargs,
)
_chain_nl_output_message(parsing_message)
except mapping.MappingError as e:
_chain_nl_output_message(e.lm_response)
raise e
return parsing_message if returns_message else parsing_message.result


Expand Down
25 changes: 25 additions & 0 deletions langfun/core/structured/parsing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,31 @@ def test_call_with_parsing_message_chaining(self):
],
returns_message=True,
)
self.assertIn('parsing-lm-output', output.tags)
self.assertIn('parsing-lm-input', output.source.tags)
self.assertEqual(output.root.text, 'Compute 1 + 2')

def test_call_with_parsing_message_chaining_on_parsing_error(self):
try:
output = parsing.call(
'Compute 1 + 2',
int,
lm=fake.StaticSequence(['three']),
parsing_lm=fake.StaticSequence(['abc']),
parsing_examples=[
mapping.MappingExample(
context='Multiple four and five',
input='twenty',
schema=int,
output=20,
)
],
returns_message=True,
)
except mapping.MappingError as e:
output = e.lm_response
self.assertIn('parsing-lm-output', output.tags)
self.assertIn('parsing-lm-input', output.source.tags)
self.assertEqual(output.root.text, 'Compute 1 + 2')

def test_call_with_autofix(self):
Expand Down

0 comments on commit 66331b5

Please sign in to comment.