Skip to content

Commit

Permalink
Allow lf.call to accept both str and lf.Template (also `lf.Lang…
Browse files Browse the repository at this point in the history
…Func`) object as prompt.

Usage:
```python
# Call with constant string-type prompt.
lf.call('Compute one plus one', lm=lf.llms.Gpt35())
>> "two"

# Call with returning a structured (int) type.
lf.call('Compute one plus one', int, lm=lf.llms.Gpt35())
>> 2

# Call with a template string with variables.
lf.call('Compute {{x}} plus {{y}}', int,
        x='one', y='one', lm=lf.llms.Gpt35())
>> 2

# Call with an `lf.Template` object with variables.
lf.call(lf.Template('Compute {{x}} plus {{y}}', x=1),
        y=1, lm=lf.llms.Gpt35())
>> 2
```

PiperOrigin-RevId: 570096116
  • Loading branch information
daiyip authored and langfun authors committed Oct 2, 2023
1 parent c5a6a88 commit 26937c2
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 10 deletions.
34 changes: 30 additions & 4 deletions langfun/core/langfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,16 +511,31 @@ class LangFuncCallEvent(subscription.Event[LangFunc]):
lm_callstack: list[LangFunc]


def call(prompt: str, returns: Any = None, **kwargs) -> Any:
def call(
prompt: str | template_lib.Template,
returns: Any = None, **kwargs
) -> Any:
"""Call a language model with prompt and formulate response in return type.
Examples::
# Call with constant string-type prompt.
lf.call('Compute one plus one', lm=lf.llms.Gpt35())
# Returns "two".
>> "two"
# Call with returning a structured (int) type.
lf.call('Compute one plus one', int, lm=lf.llms.Gpt35())
# Returns 2.
>> 2
# Call with a template string with variables.
lf.call('Compute {{x}} plus {{y}}', int,
x='one', y='one', lm=lf.llms.Gpt35())
>> 2
# Call with an `lf.Template` object with variables.
lf.call(lf.Template('Compute {{x}} plus {{y}}', x=1), int,
y=1, lm=lf.llms.Gpt35())
>> 2
Args:
prompt: User prompt that will be sent to LM, which could be a string or a
Expand All @@ -535,7 +550,18 @@ def call(prompt: str, returns: Any = None, **kwargs) -> Any:
Returns:
A string if `returns` is None or an instance of the return type.
"""
message = LangFunc(prompt, returns=returns)(**kwargs)
if isinstance(prompt, LangFunc):
lfun = prompt.as_structured(returns)
elif isinstance(prompt, template_lib.Template):
lfun = LangFunc(prompt.render(**kwargs).text, returns=returns)
elif isinstance(prompt, str):
lfun = LangFunc(prompt, returns=returns)
else:
raise TypeError(
'`prompt` should be a string or an `lf.Template` object. '
f'Encountered {prompt!r}.')

message = lfun(**kwargs)
if returns is None:
return message.text
return message.result
48 changes: 44 additions & 4 deletions langfun/core/langfunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from langfun.core import message
from langfun.core import message_transform
from langfun.core import subscription
from langfun.core import template as template_lib
from langfun.core.langfunc import call
from langfun.core.langfunc import LangFunc
from langfun.core.langfunc import LangFuncCallEvent
Expand Down Expand Up @@ -373,13 +374,52 @@ def on_event(self, event: LangFuncCallEvent):

class CallTest(unittest.TestCase):

def test_call(self):
with component.context(lm=fake.StaticSequence(['three'])):
self.assertEqual(call('Compute 1 + 1'), 'three')
def test_call_with_const_str(self):
with component.context(lm=fake.StaticMapping({
'Compute 1 + 2': 'three',
})):
self.assertEqual(call('Compute 1 + 2'), 'three')

def test_call_with_template_str(self):
with component.context(lm=fake.StaticMapping({
'Compute 1 + 2': 'three',
})):
self.assertEqual(call('Compute {{x}} + {{y}}', x=1, y=2), 'three')

def test_call_with_explicit_template(self):
with component.context(lm=fake.StaticMapping({
'Compute 1 + 2': 'three',
})):
self.assertEqual(
call(template_lib.Template('Compute {{x}} + {{y}}', x=1, y=2)),
'three')

with component.context(lm=fake.StaticMapping({
'Compute 1 + 2': 'three',
})):
self.assertEqual(
call(template_lib.Template('Compute {{x}} + {{y}}'), x=1, y=2),
'three')

def test_call_with_lfun(self):
with component.context(lm=fake.StaticMapping({
'Compute 1 + 2': 'three',
})):
self.assertEqual(
call(LangFunc('Compute {{x}} + {{y}}', x=1, y=2)),
'three')

def test_call_with_returns(self):
with component.context(lm=fake.StaticSequence(['three', '3'])):
self.assertEqual(call('Compute 1 + 1', returns=int), 3)
self.assertEqual(call('Compute 1 + 2', returns=int), 3)

with component.context(lm=fake.StaticSequence(['three', '3'])):
self.assertEqual(
call(LangFunc('Compute {{x}} + {{y}}', x=1, y=2), returns=int), 3)

def test_bad_call(self):
with self.assertRaisesRegex(TypeError, '`prompt` should be .*'):
call(1)


if __name__ == '__main__':
Expand Down
7 changes: 5 additions & 2 deletions langfun/core/structured/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _parse_structure_cls(

def as_structured(
self,
annotation: Union[Type[Any], list[Type[Any]], dict[str, Any]],
annotation: Union[Type[Any], list[Type[Any]], dict[str, Any], None],
default: Any = lf.message_transform.RAISE_IF_HAS_ERROR,
examples: list[mapping.Mapping] | None = None,
*,
Expand All @@ -194,7 +194,8 @@ def as_structured(
Args:
self: The Message transform object.
annotation: The annotation used for representing the structured output. E.g.
int, list[int], {'x': int, 'y': str}, A.
int, list[int], {'x': int, 'y': str}, A. If None, the return value will be
the original LM response (str).
default: The default value to use if parsing failed. If not specified, error
will be raised.
examples: An optional list of fewshot examples for helping parsing. If None,
Expand All @@ -207,6 +208,8 @@ def as_structured(
Returns:
The structured output according to the annotation.
"""
if annotation is None:
return self
if examples is None:
examples = _default_parsing_examples()
return self >> _parse_structure_cls(protocol)(
Expand Down

0 comments on commit 26937c2

Please sign in to comment.