diff --git a/langfun/core/structured/parsing.py b/langfun/core/structured/parsing.py index 50b418f..d54c23a 100644 --- a/langfun/core/structured/parsing.py +++ b/langfun/core/structured/parsing.py @@ -270,24 +270,31 @@ def call( if schema in (str, None): return lm_output if returns_message else lm_output.text + def _chain_nl_output_message(parsing_message: lf.Message): + """Chain the source of the parsed output to the LM output.""" + parsing_message.root.source = lm_output + parsing_message.tag('parsing-lm-output') + parsing_message.lm_input.tag('parsing-lm-input') + # Call `parsing_lm` for structured parsing. - parsing_message = querying.query( - lm_output.text, - schema, - examples=parsing_examples, - lm=parsing_lm or lm, - include_context=parsing_include_context, - cache_seed=cache_seed, - autofix=autofix, - autofix_lm=autofix_lm or lm, - protocol=protocol, - returns_message=True, - **kwargs, - ) - # Chain the source of the parsed output to the LM output. - parsing_message.root.source = lm_output - parsing_message.tag('parsing-lm-output') - parsing_message.lm_input.tag('parsing-lm-input') + try: + parsing_message = querying.query( + lm_output.text, + schema, + examples=parsing_examples, + lm=parsing_lm or lm, + include_context=parsing_include_context, + cache_seed=cache_seed, + autofix=autofix, + autofix_lm=autofix_lm or lm, + protocol=protocol, + returns_message=True, + **kwargs, + ) + _chain_nl_output_message(parsing_message) + except mapping.MappingError as e: + _chain_nl_output_message(e.lm_response) + raise e return parsing_message if returns_message else parsing_message.result diff --git a/langfun/core/structured/parsing_test.py b/langfun/core/structured/parsing_test.py index 802d0bd..a284185 100644 --- a/langfun/core/structured/parsing_test.py +++ b/langfun/core/structured/parsing_test.py @@ -686,6 +686,31 @@ def test_call_with_parsing_message_chaining(self): ], returns_message=True, ) + self.assertIn('parsing-lm-output', output.tags) + self.assertIn('parsing-lm-input', output.source.tags) + self.assertEqual(output.root.text, 'Compute 1 + 2') + + def test_call_with_parsing_message_chaining_on_parsing_error(self): + try: + output = parsing.call( + 'Compute 1 + 2', + int, + lm=fake.StaticSequence(['three']), + parsing_lm=fake.StaticSequence(['abc']), + parsing_examples=[ + mapping.MappingExample( + context='Multiple four and five', + input='twenty', + schema=int, + output=20, + ) + ], + returns_message=True, + ) + except mapping.MappingError as e: + output = e.lm_response + self.assertIn('parsing-lm-output', output.tags) + self.assertIn('parsing-lm-input', output.source.tags) self.assertEqual(output.root.text, 'Compute 1 + 2') def test_call_with_autofix(self):