From fafbf1dd3372e892f775417f4f455935fb3ddb0d Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Sun, 1 Oct 2023 06:22:35 -0700 Subject: [PATCH] Introduce `lf.call` as a convenient helper for directly invoking `lf.LangFunc` with possible structured output. Usage: ``` lf.call('Compute one plus one', lm=lf.llms.Gpt35()) # Returns "two". lf.call('Compute one plus one', int, lm=lf.llms.Gpt35()) # Returns 2. ``` PiperOrigin-RevId: 569859168 --- langfun/core/__init__.py | 3 +++ langfun/core/langfunc.py | 30 +++++++++++++++++++++++++++ langfun/core/langfunc_test.py | 39 +++++++++++++++++++---------------- 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/langfun/core/__init__.py b/langfun/core/__init__.py index 2387742c..5d3e3f90 100644 --- a/langfun/core/__init__.py +++ b/langfun/core/__init__.py @@ -27,6 +27,9 @@ from langfun.core.template import Template from langfun.core.langfunc import LangFunc +# Function for convenient call LM with possible structured output. +from langfun.core.langfunc import call + # Decorator for set the positional init args for component. from langfun.core.component import use_init_args diff --git a/langfun/core/langfunc.py b/langfun/core/langfunc.py index 19f535be..7307fd86 100644 --- a/langfun/core/langfunc.py +++ b/langfun/core/langfunc.py @@ -509,3 +509,33 @@ class LangFuncCallEvent(subscription.Event[LangFunc]): lm_input: message_lib.Message lm_output: message_lib.Message lm_callstack: list[LangFunc] + + +def call(prompt: str, returns: Any = None, **kwargs) -> Any: + """Call a language model with prompt and formulate response in return type. + + Examples:: + + lf.call('Compute one plus one', lm=lf.llms.Gpt35()) + # Returns "two". + + lf.call('Compute one plus one', int, lm=lf.llms.Gpt35()) + # Returns 2. + + Args: + prompt: User prompt that will be sent to LM, which could be a string or a + string template whose variables are provided from **kwargs. + returns: Type annotations for return type. If None, the raw LM response will + be returned (str). Otherwise, the response will be parsed based on the + return type. + **kwargs: Keyword arguments. Including options that control the calling + behavior, such as `lm`, `temperature`, etc. As well as variables that will + be fed to the prompt if it's a string template. + + Returns: + A string if `returns` is None or an instance of the return type. + """ + message = LangFunc(prompt, returns=returns)(**kwargs) + if returns is None: + return message.text + return message.result diff --git a/langfun/core/langfunc_test.py b/langfun/core/langfunc_test.py index 82c0cc8b..77074afb 100644 --- a/langfun/core/langfunc_test.py +++ b/langfun/core/langfunc_test.py @@ -19,9 +19,12 @@ from langfun.core import message from langfun.core import message_transform from langfun.core import subscription +from langfun.core.langfunc import call from langfun.core.langfunc import LangFunc from langfun.core.langfunc import LangFuncCallEvent from langfun.core.llms import fake +# Enables as_structured() operation of LangFunc. +from langfun.core.structured import parsing # pylint: disable=unused-import import pyglove as pg @@ -69,7 +72,7 @@ def test_cached_lm_input_and_output(self): self.assertEqual(l.lm_output, 'Hello!!!') -class CallTest(unittest.TestCase): +class LangFuncCallTest(unittest.TestCase): def test_call(self): l = LangFunc('Hello', lm=ExcitedEchoer()) @@ -238,32 +241,21 @@ def test_call_with_overriden_lm_input(self): self.assertEqual(t(lm_input=message.UserMessage('Hi')), 'Hi!!!') def test_call_with_structured_output(self): - - class FakeParseStructured(message_transform.MessageTransform): - - def _transform_path(self, unused_message, input_path, value): - return int(str(value)) - - prev_as_structured = message_transform.MessageTransform.as_structured - message_transform.MessageTransform.as_structured = ( - lambda self, *args: self >> FakeParseStructured()) - l = LangFunc('Compute 1 + 2', returns=int) - with component.context(lm=fake.StaticMapping({ - 'Compute 1 + 2': '3', - })): + with component.context(lm=fake.StaticSequence([ + 'three', '3' + ])): r = l() self.assertEqual(r.result, 3) l = LangFunc('Compute 1 + 2', returns=int, output_transform=lambda x: '3') with component.context( - lm=fake.StaticMapping({ - 'Compute 1 + 2': 'three', - }) + lm=fake.StaticSequence([ + 'three', '3' + ]) ): r = l() self.assertEqual(r.result, 3) - message_transform.MessageTransform.as_structured = prev_as_structured class TransformTest(unittest.TestCase): @@ -379,5 +371,16 @@ def on_event(self, event: LangFuncCallEvent): ) +class CallTest(unittest.TestCase): + + def test_call(self): + with component.context(lm=fake.StaticSequence(['three'])): + self.assertEqual(call('Compute 1 + 1'), 'three') + + def test_call_with_returns(self): + with component.context(lm=fake.StaticSequence(['three', '3'])): + self.assertEqual(call('Compute 1 + 1', returns=int), 3) + + if __name__ == '__main__': unittest.main()