Skip to content

Commit

Permalink
Simplify langfun by removing MessageTransform.
Browse files Browse the repository at this point in the history
The concept of MessageTransform adds a flavor of Tensorflow (constructing graphs) to langfun, which is unnecessarily complex and hard to debug. Plus, the effectiveness of structured parsing/prompting minimizes the need for a flexible message transform pipeline. This CL completely removes this concept, so LangFunc can be used as a regular Python function. Consequently, `as_structured` is also removed - users should always use `lf.call` with schema in such cases.

PiperOrigin-RevId: 575556622
  • Loading branch information
daiyip authored and langfun authors committed Oct 22, 2023
1 parent 70dd522 commit da1479f
Show file tree
Hide file tree
Showing 21 changed files with 192 additions and 1,848 deletions.
1 change: 0 additions & 1 deletion langfun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from langfun.core import eval # pylint: disable=redefined-builtin
from langfun.core import templates
from langfun.core import transforms
from langfun.core import coding

PythonCode = coding.PythonCode
Expand Down
8 changes: 3 additions & 5 deletions langfun/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
# pylint: disable=g-importing-member
# pylint: disable=g-import-not-at-top

# Constants
from langfun.core.component import RAISE_IF_HAS_ERROR

# Interface for all langfun components.
from langfun.core.component import Component

Expand Down Expand Up @@ -88,11 +91,6 @@
from langfun.core.message import SystemMessage
from langfun.core.message import MemoryRecord


# Message transforms.
from langfun.core.message_transform import MessageTransform


# Interfaces for language models.
from langfun.core.language_model import LanguageModel
from langfun.core.language_model import LMSample
Expand Down
2 changes: 0 additions & 2 deletions langfun/core/coding/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,6 @@ def sum(x: int, y: int):
self.assertEqual(f(1, y=2), 3)
self.assertEqual(f(1, y=2, sandbox=False), 3)

@unittest.skip(
'coverage data collection failure due to terminated child process.')
def test_bad_code(self):
f = python.PythonFunction(
name='sum',
Expand Down
7 changes: 4 additions & 3 deletions langfun/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@
import pyglove as pg


_RAISE_IF_ATTR_NOT_AVAILABLE = (pg.MISSING_VALUE,)
# Default value marker that indicates to raise error.
RAISE_IF_HAS_ERROR = (pg.MISSING_VALUE,)


class Component(pg.Object):
"""Base class for langfun components."""

# Override __repr__ format to use inferred values when available.
__str_format_args__ = dict(
__repr_format_kwargs__ = dict(
compact=True,
use_inferred=True,
)

# Override __str__ format to use inferred values when available.
__str_format_args__ = dict(
__str_format_kwargs__ = dict(
compact=False,
verbose=False,
use_inferred=True,
Expand Down
8 changes: 5 additions & 3 deletions langfun/core/eval/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,13 @@ def test_query(self):
self.assertEqual(s.process(s.examples[0]), Solution(2))

# Test query with fewshot examples.
lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
lm = fake.StaticSequence(['Solution(final_answer=2)'])
s = eval_set(
'basic_test', 'call',
'basic_test',
'query',
schema_fn=answer_schema_with_fewshot_examples(),
lm=lm)
lm=lm,
)
m = s.process(s.examples[0], returns_message=True)
self.assertIn('The result of one plus two', m.lm_input.text)

Expand Down
202 changes: 18 additions & 184 deletions langfun/core/langfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
"""LangFunc: Language-based functions."""

import dataclasses
from typing import Annotated, Any, Type
from typing import Annotated

from langfun.core import component
from langfun.core import language_model
from langfun.core import message as message_lib
from langfun.core import message_transform
from langfun.core import subscription
from langfun.core import template as template_lib
import pyglove as pg
Expand All @@ -35,9 +34,6 @@
@pg.use_init_args(['template_str'])
class LangFunc(
template_lib.Template,
# LangFunc is also a component that transforms an message to another
# message, so we can chain it.
message_transform.MessageTransform,
):
r"""Base class for natural-language driven component.
Expand Down Expand Up @@ -191,28 +187,6 @@ class Chat(lt.LangFunc):
),
] = component.contextual()

input_transform: Annotated[
message_transform.MessageTransform | None,
(
'External input transform, which intercepts LM input before calling '
'the internal `transform_input` method. It is designed to apply '
'extra structures to the LM input (e.g. COT).'
'We set the default value to None as we do not want the child '
"LangFun to use the parent's transform accidentally."
),
] = None

output_transform: Annotated[
message_transform.MessageTransform | None,
(
'External output transform, which intercepts LM response before '
'calling the internal `transform_output` method. It is designed to '
'clean up LM response before structured parsing. We set the default '
'value to None as we do not want the child LangFun to use the '
"parent's transform accidentally."
),
] = None

def _on_bound(self):
super()._on_bound()

Expand All @@ -236,9 +210,6 @@ def __call__(
lm: language_model.LanguageModel | None = None,
lm_input: message_lib.Message | None = None,
cache_seed: int | None = 0,
skip_input_transform: bool = False,
skip_lm: bool = False,
skip_output_transform: bool = False,
**variables,
) -> message_lib.Message:
"""Calls language model with `lm_input` or rendered text.
Expand All @@ -252,10 +223,6 @@ def __call__(
cache_seed: Seed for computing cache key. The cache key is determined by a
tuple of (lm, prompt, cache seed). If None, cache will be disabled for
the query even if cache is configured by the LM.
skip_input_transform: If True, the input transform will be skipped.
skip_lm: If True, skipping LM. In such case, the input message will be
returned.
skip_output_transform: If True, the output transform will be skipped.
**variables: Template variables applicable to this or child LangFunc.
Returns:
Expand All @@ -265,9 +232,6 @@ def __call__(
lm=lm,
lm_input=lm_input,
cache_seed=cache_seed,
skip_input_transform=skip_input_transform,
skip_lm=skip_lm,
skip_output_transform=skip_output_transform,
**variables,
)

Expand All @@ -277,9 +241,6 @@ def _call_once(
lm: language_model.LanguageModel | None = None,
lm_input: message_lib.Message | None = None,
cache_seed: int | None = 0,
skip_input_transform: bool = False,
skip_lm: bool = False,
skip_output_transform: bool = False,
**variables,
) -> message_lib.Message:
"""Call the language model once, with invoking the output transform."""
Expand All @@ -293,36 +254,27 @@ def _call_once(
with self.override(**kwargs):
# Render the LM input text and creates a user message.
if lm_input is None:
lm_input = self.render(
skip_input_transform=skip_input_transform, **kwargs)
self._cached_lm_input = lm_input

if not skip_lm:
# Send rendered text to LM.
lm_input.tag(message_lib.Message.TAG_LM_INPUT)
lm_output = self.lm(lm_input, cache_seed=cache_seed)

# Track the input as the source of the output.
lm_output.source = lm_input
lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
lm_input = self.render(**kwargs)

# Transform the output message if applicable.
if not skip_output_transform:
# Transform the input message.
lm_input = self.transform_input(lm_input)
self._cached_lm_input = lm_input

# Call the external output transform first to clean up LM response.
if self.output_transform is not None:
lm_output = self.output_transform.transform(lm_output)
# Send rendered text to LM.
lm_input.tag(message_lib.Message.TAG_LM_INPUT)
lm_output = self.lm(lm_input, cache_seed=cache_seed)

lm_output = self.transform_output(lm_output)
# Track the input as the source of the output.
lm_output.source = lm_input
lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)

lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
# Transform the output message.
lm_output = self.transform_output(lm_output)
lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)

# We cache the transformed output instead of the original one
# since the old one is tracked with `sym_origin`.
self._cached_lm_output = lm_output
else:
lm_output = lm_input
self._cached_lm_output = None
# We cache the transformed output instead of the original one
# since the old one is tracked with `sym_origin`.
self._cached_lm_output = lm_output

# Emit LangFuncCallEvent.
lm_callstack = list(
Expand All @@ -341,65 +293,8 @@ def _call_once(
top = pg.object_utils.thread_local_pop(_TLS_LFUN_CALL_STACK, self)
assert top is self, (top, self)

def render(
self,
*,
allow_partial: bool = False,
implicit: bool = False,
skip_input_transform: bool = False,
message_cls: Type[message_lib.Message] = message_lib.UserMessage,
**kwargs
) -> message_lib.Message:
"""Renders the template with variables from the context.
Args:
allow_partial: Allow partial rendering, this means that unresolved
variables are allowed and remain in the output text.
implicit: If True, reuse the rendering output if a parent LangFunc
is rendering current LangFunc multiple times. This is important
for making sure all references to the same LangFunc within a single
top-level rendering would return the same result. If False, every call
to `render` will trigger the actual rendering process.
skip_input_transform: If True, the input transform will be skipped.
message_cls: The message class used for creating the return value.
**kwargs: Values for template variables.
Returns:
An Message object as the rendered result.
"""
render_output = super().render(
allow_partial=allow_partial,
implicit=implicit,
message_cls=message_cls,
**kwargs,
)

# Transform the input message if applicable.
if not skip_input_transform:
# Call the external input transform first.
render_transformed = render_output
if self.input_transform is not None:
render_transformed = self.input_transform.transform(render_transformed)
render_transformed = self.transform_input(render_transformed)

if render_transformed is render_output and isinstance(
render_transformed.result, str
):
render_transformed = render_output.clone(
override={
'text': render_output.result,
'tags': [],
'metadata.result': pg.MISSING_VALUE,
}
)
render_transformed.source = render_output
render_transformed.tag(message_lib.Message.TAG_TRANSFORMED)

render_output = render_transformed
return render_output

#
# Internal input and output transforms.
# Input and output transforms.
# Subclasses can override.
#

Expand All @@ -413,67 +308,6 @@ def transform_output(
"""Transforms the output message before returning from __call__."""
return lm_output

#
# Override MessageTransform methods.
#

def _transform_path(
self,
message: message_lib.Message,
input_path: str,
value: Any
) -> message_lib.Message:
"""Implements MessageTransform._transform_path."""
if input_path in (
message_lib.Message.PATH_TEXT, message_lib.Message.PATH_ROOT):
input_message = message
else:
if isinstance(value, message_lib.Message):
message.set(input_path, pg.MISSING_VALUE)
input_message = value
elif isinstance(value, str):
input_message = message.clone(override={
'text': value,
'tags': [message_lib.Message.TAG_TRANSFORMED],
f'metadata.{input_path}': pg.MISSING_VALUE
})
else:
raise TypeError(
f'Metadata {repr(input_path)} should be a string or '
f'a `lf.Message`. Encountered: {value!r}'
)
input_message.source = message

# For LangFunc that are used as transforms, its template could access the
# input via 'message'.
output_message = self(message=input_message)

# Trace back the source for the root.
output_message.root.source = input_message
return output_message

def __rshift__(self, x):
"""Override >> to chain output transform and return self."""
self.rebind(
output_transform=(
self.output_transform >> message_transform.make_transform(x)),
skip_notification=True
)
return self

#
# Implements NaturalLanguageFormattable
#

def __repr__(self) -> str:
exclude_keys = []
if self.input_path is None:
exclude_keys.append('input_path')
if self.output_path is None:
exclude_keys.append('output_path')
return self.format(
compact=True, use_inferred=True, exclude_keys=exclude_keys)


# Register converter from str to LangFunc, therefore we can always
# pass strs to attributes that accept LangFunc.
Expand Down
Loading

0 comments on commit da1479f

Please sign in to comment.