From 74ee960149386302f838d995c832d5d5e1e6afe8 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Thu, 19 Dec 2024 15:13:30 -0800 Subject: [PATCH] `lf.QueryInvocation.output` to return `lf.MappingError` when OOP failed. Previously accessing `output` for bad LLM response will raise. PiperOrigin-RevId: 708050827 --- langfun/core/structured/querying.py | 11 ++++++++++- langfun/core/structured/querying_test.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/langfun/core/structured/querying.py b/langfun/core/structured/querying.py index f7f0af0..d73c226 100644 --- a/langfun/core/structured/querying.py +++ b/langfun/core/structured/querying.py @@ -583,7 +583,16 @@ def lm_request(self) -> lf.Message: @functools.cached_property def output(self) -> Any: - return query_output(self.lm_response, self.schema) + """The output of `lf.query`. If it failed, returns the `MappingError`.""" + try: + return query_output(self.lm_response, self.schema) + except mapping.MappingError as e: + return e + + @property + def has_error(self) -> bool: + """Returns True if the query failed to generate a valid output.""" + return isinstance(self.output, BaseException) @property def elapse(self) -> float: diff --git a/langfun/core/structured/querying_test.py b/langfun/core/structured/querying_test.py index 8d32859..735ee41 100644 --- a/langfun/core/structured/querying_test.py +++ b/langfun/core/structured/querying_test.py @@ -1051,6 +1051,16 @@ def test_query(self): class QueryInvocationTest(unittest.TestCase): + def test_basics(self): + lm = fake.StaticSequence([ + 'Activity(description="hi"', + ]) + with querying.track_queries() as queries: + querying.query('foo', Activity, default=None, lm=lm) + + self.assertTrue(queries[0].has_error) + self.assertIsInstance(queries[0].output, mapping.MappingError) + def test_to_html(self): lm = fake.StaticSequence([ 'Activity(description="hi")',