Skip to content

Commit

Permalink
Style fix for chat_adapter (stanfordnlp#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub authored Nov 23, 2024
1 parent 1a577f8 commit 756b619
Showing 1 changed file with 92 additions and 73 deletions.
165 changes: 92 additions & 73 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
import re
from typing import Any, Union
from dsp.adapters.base_template import Field
from dspy.signatures.signature import Signature
from .base import Adapter
from .image_utils import encode_image, Image

import ast
import json
import enum
import inspect
import pydantic
import json
import re
import textwrap
from collections.abc import Mapping
from itertools import chain
from typing import Any, Dict, List, Literal, NamedTuple, Union, get_args, get_origin

import pydantic
from pydantic import TypeAdapter
from collections.abc import Mapping
from pydantic.fields import FieldInfo
from typing import Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

from dsp.adapters.base_template import Field
from dspy.adapters.base import Adapter
from ..signatures.field import OutputField
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type
from dspy.adapters.image_utils import Image, encode_image
from dspy.signatures.field import OutputField
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.signatures.utils import get_dspy_field_type

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")

Expand All @@ -33,12 +30,15 @@ class FieldInfoWithName(NamedTuple):
# Built-in field indicating that a chat turn has been completed.
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())


class ChatAdapter(Adapter):
def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []

# Extract demos where some of the output_fields are not filled in.
incomplete_demos = [demo for demo in demos if not all(k in demo and demo[k] is not None for k in signature.fields)]
incomplete_demos = [
demo for demo in demos if not all(k in demo and demo[k] is not None for k in signature.fields)
]
complete_demos = [demo for demo in demos if demo not in incomplete_demos]
# Filter out demos that don't have at least one input and one output field.
incomplete_demos = [
Expand Down Expand Up @@ -99,6 +99,7 @@ def format_finetune_data(self, signature, demos, inputs, outputs):

# Wrap the messages in a dictionary with a "messages" key
return dict(messages=messages)

def format_turn(self, signature, values, role, incomplete=False):
return format_turn(signature, values, role, incomplete)

Expand All @@ -112,8 +113,7 @@ def format_fields(self, signature, values, role):
}

return format_fields(fields_with_values)




def format_blob(blob):
if "\n" not in blob and "«" not in blob and "»" not in blob:
Expand All @@ -139,6 +139,7 @@ def format_input_list_field_value(value: List[Any]) -> str:

return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])


def _serialize_for_json(value):
if isinstance(value, pydantic.BaseModel):
return value.model_dump()
Expand All @@ -149,6 +150,7 @@ def _serialize_for_json(value):
else:
return value


def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
Expand All @@ -171,7 +173,7 @@ def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) ->

if assume_text:
return string_value
elif (isinstance(value, Image) or field_info.annotation == Image):
elif isinstance(value, Image) or field_info.annotation == Image:
# This validation should happen somewhere else
# Safe to import PIL here because it's only imported when an image is actually being formatted
try:
Expand All @@ -193,7 +195,6 @@ def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) ->
return {"type": "text", "text": string_value}



def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
Expand Down Expand Up @@ -222,10 +223,11 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
else:
return output


def parse_value(value, annotation):
if annotation is str:
return str(value)

parsed_value = value

if isinstance(annotation, enum.EnumMeta):
Expand All @@ -238,70 +240,85 @@ def parse_value(value, annotation):
parsed_value = ast.literal_eval(value)
except (ValueError, SyntaxError):
parsed_value = value

return TypeAdapter(annotation).validate_python(parsed_value)


def format_turn(signature, values, role, incomplete=False):
fields_to_collapse = []
def format_turn(signature, values, role, incomplete=False):
"""
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
Args:
signature: The DSPy signature to which future LLM responses should conform.
values: A dictionary mapping field names (from the DSPy signature) to corresponding values
that should be included in the message.
role: The role of the message, which can be either "user" or "assistant".
incomplete: If True, indicates that output field values are present in the set of specified
``values``. If False, indicates that ``values`` only contains input field values.
signature: The DSPy signature to which future LLM responses should conform.
values: A dictionary mapping field names (from the DSPy signature) to corresponding values
that should be included in the message.
role: The role of the message, which can be either "user" or "assistant".
incomplete: If True, indicates that output field values are present in the set of specified
``values``. If False, indicates that ``values`` only contains input field values.
Returns:
A chat message that can be appended to a chat thread. The message contains two string fields:
``role`` ("user" or "assistant") and ``content`` (the message text).
A chat message that can be appended to a chat thread. The message contains two string fields:
``role`` ("user" or "assistant") and ``content`` (the message text).
"""
fields_to_collapse = []
content = []

if role == "user":
fields: Dict[str, FieldInfo] = signature.input_fields
fields = signature.input_fields
if incomplete:
fields_to_collapse.append({"type": "text", "text": "This is an example of the task, though some input or output fields are not supplied."})
fields_to_collapse.append(
{
"type": "text",
"text": "This is an example of the task, though some input or output fields are not supplied.",
}
)
else:
fields: Dict[str, FieldInfo] = signature.output_fields
fields = signature.output_fields
# Add the built-in field indicating that the chat turn has been completed
fields[BuiltInCompletedOutputFieldInfo.name] = BuiltInCompletedOutputFieldInfo.info
values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}
field_names: KeysView = fields.keys()
field_names = fields.keys()
if not incomplete:
if not set(values).issuperset(set(field_names)):
raise ValueError(f"Expected {field_names} but got {values.keys()}")

fields_to_collapse.extend(format_fields(
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in fields.items()
},
assume_text=False
))

fields_to_collapse.extend(
format_fields(
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in fields.items()
},
assume_text=False,
)
)

if role == "user":
output_fields = list(signature.output_fields.keys())

def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""
return (
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
if v.annotation is not str
else ""
)

if output_fields:
fields_to_collapse.append({
"type": "text",
"text": "Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
})

fields_to_collapse.append(
{
"type": "text",
"text": "Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`.",
}
)

# flatmap the list if any items are lists otherwise keep the item
flattened_list = list(chain.from_iterable(
item if isinstance(item, list) else [item] for item in fields_to_collapse
))
flattened_list = list(
chain.from_iterable(item if isinstance(item, list) else [item] for item in fields_to_collapse)
)

if all(message.get("type", None) == "text" for message in flattened_list):
content = "\n\n".join(message.get("text") for message in flattened_list)
Expand All @@ -314,16 +331,16 @@ def type_info(v):
if not collapsed_messages:
collapsed_messages.append(item)
continue
# If current item is image, add to collapsed_messages

# If the current item is image, add to collapsed_messages
if item.get("type") == "image_url":
if collapsed_messages[-1].get("type") == "text":
collapsed_messages[-1]["text"] += "\n"
collapsed_messages.append(item)
# If previous item is text and current item is text, append to previous item
# If the previous item is text and current item is text, append to the previous item
elif collapsed_messages[-1].get("type") == "text":
collapsed_messages[-1]["text"] += "\n\n" + item["text"]
# If previous item is not text(aka image), add current item as a new item
# If the previous item is not text(aka image), add the current item as a new item
else:
item["text"] = "\n\n" + item["text"]
collapsed_messages.append(item)
Expand Down Expand Up @@ -357,38 +374,40 @@ def enumerate_fields(fields: dict[str, Field]) -> str:
def move_type_to_front(d):
    """Recursively reorder mappings so that the "type" key comes first.

    Placing "type" at the front of each dictionary is done for LLM
    readability/adherence. The remaining keys follow in sorted order.

    Args:
        d: Any value. Mappings are rebuilt as plain dicts with reordered keys,
            lists are rebuilt element by element, and anything else is
            returned unchanged.

    Returns:
        The reordered structure, or ``d`` itself when it is neither a mapping
        nor a list.
    """
    if isinstance(d, Mapping):
        # The sort key is (is-not-"type", key): False sorts before True, so
        # "type" lands first and every other key keeps its natural ordering.
        ordered_pairs = sorted(d.items(), key=lambda pair: (pair[0] != "type", pair[0]))
        reordered = {}
        for key, value in ordered_pairs:
            reordered[key] = move_type_to_front(value)
        return reordered
    if isinstance(d, list):
        return [move_type_to_front(element) for element in d]
    return d


def prepare_schema(type_):
    """Return the JSON schema for ``type_`` with keys reordered by ``move_type_to_front``.

    Args:
        type_: The type handed to ``pydantic.TypeAdapter`` to generate the schema.

    Returns:
        The result of passing the generated JSON schema through
        ``move_type_to_front``.
    """
    return move_type_to_front(pydantic.TypeAdapter(type_).json_schema())


def prepare_instructions(signature: SignatureMeta):
parts = []
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

def field_metadata(field_name, field_info):
type_ = field_info.annotation
field_type = field_info.annotation

if get_dspy_field_type(field_info) == 'input' or type_ is str:
if get_dspy_field_type(field_info) == "input" or field_type is str:
desc = ""
elif type_ is bool:
elif field_type is bool:
desc = "must be True or False"
elif type_ in (int, float):
desc = f"must be a single {type_.__name__} value"
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
desc= f"must be one of: {'; '.join(type_.__members__)}"
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal:
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
elif field_type in (int, float):
desc = f"must be a single {field_type.__name__} value"
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
desc = f"must be one of: {'; '.join(field_type.__members__)}"
elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
desc = f"must be one of: {'; '.join([str(x) for x in field_type.__args__])}"
else:
desc = "must be pareseable according to the following JSON schema: "
desc += json.dumps(prepare_schema(type_), ensure_ascii=False)
desc += json.dumps(prepare_schema(field_type), ensure_ascii=False)

desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"
Expand All @@ -399,7 +418,7 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
for field_name, field_info in fields.items()
},
assume_text=True
assume_text=True,
)

parts.append(format_signature_fields_for_instructions(signature.input_fields))
Expand Down

0 comments on commit 756b619

Please sign in to comment.