Skip to content

Commit

Permalink
Introduce lf.call as a convenient helper for directly invoking `lf.…
Browse files Browse the repository at this point in the history
…LangFunc` with possible structured output.

Usage:
```
lf.call('Compute one plus one', lm=lf.llms.Gpt35())
# Returns "two".

lf.call('Compute one plus one', int, lm=lf.llms.Gpt35())
# Returns 2.
```

PiperOrigin-RevId: 569859168
  • Loading branch information
daiyip authored and langfun authors committed Oct 1, 2023
1 parent 38f0a23 commit 8e3ac51
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 18 deletions.
3 changes: 3 additions & 0 deletions langfun/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from langfun.core.template import Template
from langfun.core.langfunc import LangFunc

# Function for convenient call LM with possible structured output.
from langfun.core.langfunc import call

# Decorator for set the positional init args for component.
from langfun.core.component import use_init_args

Expand Down
30 changes: 30 additions & 0 deletions langfun/core/langfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,33 @@ class LangFuncCallEvent(subscription.Event[LangFunc]):
lm_input: message_lib.Message
lm_output: message_lib.Message
lm_callstack: list[LangFunc]


def call(prompt: str, returns: Any = None, **kwargs) -> Any:
"""Call a language model with prompt and formulate response in return type.
Examples::
lf.call('Compute one plus one', lm=lf.llms.Gpt35())
# Returns "two".
lf.call('Compute one plus one', int, lm=lf.llms.Gpt35())
# Returns 2.
Args:
prompt: User prompt that will be sent to LM, which could be a string or a
string template whose variables are provided from **kwargs.
returns: Type annotations for return type. If None, the raw LM response will
be returned (str). Otherwise, the response will be parsed based on the
return type.
**kwargs: Keyword arguments. Including options that control the calling
behavior, such as `lm`, `temperature`, etc. As well as variables that will
be fed to the prompt if it's a string template.
Returns:
A string if `returns` is None or an instance of the return type.
"""
message = LangFunc(prompt, returns=returns)(**kwargs)
if returns is None:
return message.text
return message.result
39 changes: 21 additions & 18 deletions langfun/core/langfunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from langfun.core import message
from langfun.core import message_transform
from langfun.core import subscription
from langfun.core.langfunc import call
from langfun.core.langfunc import LangFunc
from langfun.core.langfunc import LangFuncCallEvent
from langfun.core.llms import fake
# Enables as_structured() operation of LangFunc.
from langfun.core.structured import parsing # pylint: disable=unused-import
import pyglove as pg


Expand Down Expand Up @@ -69,7 +72,7 @@ def test_cached_lm_input_and_output(self):
self.assertEqual(l.lm_output, 'Hello!!!')


class CallTest(unittest.TestCase):
class LangFuncCallTest(unittest.TestCase):

def test_call(self):
l = LangFunc('Hello', lm=ExcitedEchoer())
Expand Down Expand Up @@ -238,32 +241,21 @@ def test_call_with_overriden_lm_input(self):
self.assertEqual(t(lm_input=message.UserMessage('Hi')), 'Hi!!!')

def test_call_with_structured_output(self):

class FakeParseStructured(message_transform.MessageTransform):

def _transform_path(self, unused_message, input_path, value):
return int(str(value))

prev_as_structured = message_transform.MessageTransform.as_structured
message_transform.MessageTransform.as_structured = (
lambda self, *args: self >> FakeParseStructured())

l = LangFunc('Compute 1 + 2', returns=int)
with component.context(lm=fake.StaticMapping({
'Compute 1 + 2': '3',
})):
with component.context(lm=fake.StaticSequence([
'three', '3'
])):
r = l()
self.assertEqual(r.result, 3)

l = LangFunc('Compute 1 + 2', returns=int, output_transform=lambda x: '3')
with component.context(
lm=fake.StaticMapping({
'Compute 1 + 2': 'three',
})
lm=fake.StaticSequence([
'three', '3'
])
):
r = l()
self.assertEqual(r.result, 3)
message_transform.MessageTransform.as_structured = prev_as_structured


class TransformTest(unittest.TestCase):
Expand Down Expand Up @@ -379,5 +371,16 @@ def on_event(self, event: LangFuncCallEvent):
)


class CallTest(unittest.TestCase):

def test_call(self):
with component.context(lm=fake.StaticSequence(['three'])):
self.assertEqual(call('Compute 1 + 1'), 'three')

def test_call_with_returns(self):
with component.context(lm=fake.StaticSequence(['three', '3'])):
self.assertEqual(call('Compute 1 + 1', returns=int), 3)


if __name__ == '__main__':
unittest.main()

0 comments on commit 8e3ac51

Please sign in to comment.