diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index edb5e1870..a8ae380cd 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -1,26 +1,23 @@ -import re -from typing import Any, Union -from dsp.adapters.base_template import Field -from dspy.signatures.signature import Signature -from .base import Adapter -from .image_utils import encode_image, Image - import ast -import json import enum import inspect -import pydantic +import json +import re import textwrap +from collections.abc import Mapping from itertools import chain +from typing import Any, Dict, List, Literal, NamedTuple, Union, get_args, get_origin + +import pydantic from pydantic import TypeAdapter -from collections.abc import Mapping from pydantic.fields import FieldInfo -from typing import Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin +from dsp.adapters.base_template import Field from dspy.adapters.base import Adapter -from ..signatures.field import OutputField -from ..signatures.signature import SignatureMeta -from ..signatures.utils import get_dspy_field_type +from dspy.adapters.image_utils import Image, encode_image +from dspy.signatures.field import OutputField +from dspy.signatures.signature import Signature, SignatureMeta +from dspy.signatures.utils import get_dspy_field_type field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]") @@ -33,12 +30,15 @@ class FieldInfoWithName(NamedTuple): # Built-in field indicating that a chat turn has been completed. BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField()) + class ChatAdapter(Adapter): def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] # Extract demos where some of the output_fields are not filled in. - incomplete_demos = [demo for demo in demos if not all(k in demo and demo[k] is not None for k in signature.fields)] + incomplete_demos = [ + demo for demo in demos if not all(k in demo and demo[k] is not None for k in signature.fields) + ] complete_demos = [demo for demo in demos if demo not in incomplete_demos] # Filter out demos that don't have at least one input and one output field. incomplete_demos = [ @@ -99,6 +99,7 @@ def format_finetune_data(self, signature, demos, inputs, outputs): # Wrap the messages in a dictionary with a "messages" key return dict(messages=messages) + def format_turn(self, signature, values, role, incomplete=False): return format_turn(signature, values, role, incomplete) @@ -112,8 +113,7 @@ def format_fields(self, signature, values, role): } return format_fields(fields_with_values) - - + def format_blob(blob): if "\n" not in blob and "«" not in blob and "»" not in blob: @@ -139,6 +139,7 @@ def format_input_list_field_value(value: List[Any]) -> str: return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)]) + def _serialize_for_json(value): if isinstance(value, pydantic.BaseModel): return value.model_dump() @@ -149,6 +150,7 @@ def _serialize_for_json(value): else: return value + def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]: """ Formats the value of the specified field according to the field's DSPy type (input or output), @@ -171,7 +173,7 @@ def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> if assume_text: return string_value - elif (isinstance(value, Image) or field_info.annotation == Image): + elif isinstance(value, Image) or field_info.annotation == Image: # This validation should happen somewhere else # Safe to import PIL here because it's only imported when an image is actually being formatted try: @@ -193,7 +195,6 @@ def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> return {"type": "text", "text": string_value} - def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]: """ Formats the values of the specified fields according to the field's DSPy type (input or output), @@ -222,10 +223,11 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= else: return output + def parse_value(value, annotation): if annotation is str: return str(value) - + parsed_value = value if isinstance(annotation, enum.EnumMeta): @@ -238,70 +240,85 @@ def parse_value(value, annotation): parsed_value = ast.literal_eval(value) except (ValueError, SyntaxError): parsed_value = value - + return TypeAdapter(annotation).validate_python(parsed_value) -def format_turn(signature, values, role, incomplete=False): - fields_to_collapse = [] +def format_turn(signature, values, role, incomplete=False): """ Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted so that it can instruct an LLM to generate responses conforming to the specified DSPy signature. Args: - signature: The DSPy signature to which future LLM responses should conform. - values: A dictionary mapping field names (from the DSPy signature) to corresponding values - that should be included in the message. - role: The role of the message, which can be either "user" or "assistant". - incomplete: If True, indicates that output field values are present in the set of specified - ``values``. If False, indicates that ``values`` only contains input field values. + signature: The DSPy signature to which future LLM responses should conform. + values: A dictionary mapping field names (from the DSPy signature) to corresponding values + that should be included in the message. + role: The role of the message, which can be either "user" or "assistant". + incomplete: If True, indicates that output field values are present in the set of specified + ``values``. If False, indicates that ``values`` only contains input field values. + Returns: - A chat message that can be appended to a chat thread. The message contains two string fields: - ``role`` ("user" or "assistant") and ``content`` (the message text). + A chat message that can be appended to a chat thread. The message contains two string fields: + ``role`` ("user" or "assistant") and ``content`` (the message text). """ + fields_to_collapse = [] content = [] if role == "user": - fields: Dict[str, FieldInfo] = signature.input_fields + fields = signature.input_fields if incomplete: - fields_to_collapse.append({"type": "text", "text": "This is an example of the task, though some input or output fields are not supplied."}) + fields_to_collapse.append( + { + "type": "text", + "text": "This is an example of the task, though some input or output fields are not supplied.", + } + ) else: - fields: Dict[str, FieldInfo] = signature.output_fields + fields = signature.output_fields # Add the built-in field indicating that the chat turn has been completed fields[BuiltInCompletedOutputFieldInfo.name] = BuiltInCompletedOutputFieldInfo.info values = {**values, BuiltInCompletedOutputFieldInfo.name: ""} - field_names: KeysView = fields.keys() + field_names = fields.keys() if not incomplete: if not set(values).issuperset(set(field_names)): raise ValueError(f"Expected {field_names} but got {values.keys()}") - - fields_to_collapse.extend(format_fields( - fields_with_values={ - FieldInfoWithName(name=field_name, info=field_info): values.get( - field_name, "Not supplied for this particular example." - ) - for field_name, field_info in fields.items() - }, - assume_text=False - )) + + fields_to_collapse.extend( + format_fields( + fields_with_values={ + FieldInfoWithName(name=field_name, info=field_info): values.get( + field_name, "Not supplied for this particular example." + ) + for field_name, field_info in fields.items() + }, + assume_text=False, + ) + ) if role == "user": output_fields = list(signature.output_fields.keys()) + def type_info(v): - return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \ - if v.annotation is not str else "" + return ( + f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" + if v.annotation is not str + else "" + ) + if output_fields: - fields_to_collapse.append({ - "type": "text", - "text": "Respond with the corresponding output fields, starting with the field " - + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) - + ", and then ending with the marker for `[[ ## completed ## ]]`." - }) - + fields_to_collapse.append( + { + "type": "text", + "text": "Respond with the corresponding output fields, starting with the field " + + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + + ", and then ending with the marker for `[[ ## completed ## ]]`.", + } + ) + # flatmap the list if any items are lists otherwise keep the item - flattened_list = list(chain.from_iterable( - item if isinstance(item, list) else [item] for item in fields_to_collapse - )) + flattened_list = list( + chain.from_iterable(item if isinstance(item, list) else [item] for item in fields_to_collapse) + ) if all(message.get("type", None) == "text" for message in flattened_list): content = "\n\n".join(message.get("text") for message in flattened_list) @@ -314,16 +331,16 @@ def type_info(v): if not collapsed_messages: collapsed_messages.append(item) continue - - # If current item is image, add to collapsed_messages + + # If the current item is image, add to collapsed_messages if item.get("type") == "image_url": if collapsed_messages[-1].get("type") == "text": collapsed_messages[-1]["text"] += "\n" collapsed_messages.append(item) - # If previous item is text and current item is text, append to previous item + # If the previous item is text and current item is text, append to the previous item elif collapsed_messages[-1].get("type") == "text": collapsed_messages[-1]["text"] += "\n\n" + item["text"] - # If previous item is not text(aka image), add current item as a new item + # If the previous item is not text(aka image), add the current item as a new item else: item["text"] = "\n\n" + item["text"] collapsed_messages.append(item) @@ -357,16 +374,18 @@ def enumerate_fields(fields: dict[str, Field]) -> str: def move_type_to_front(d): # Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence. if isinstance(d, Mapping): - return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != 'type', item[0]))} + return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0]))} elif isinstance(d, list): return [move_type_to_front(item) for item in d] return d + def prepare_schema(type_): schema = pydantic.TypeAdapter(type_).json_schema() schema = move_type_to_front(schema) return schema + def prepare_instructions(signature: SignatureMeta): parts = [] parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields)) @@ -374,21 +393,21 @@ def prepare_instructions(signature: SignatureMeta): parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") def field_metadata(field_name, field_info): - type_ = field_info.annotation + field_type = field_info.annotation - if get_dspy_field_type(field_info) == 'input' or type_ is str: + if get_dspy_field_type(field_info) == "input" or field_type is str: desc = "" - elif type_ is bool: + elif field_type is bool: desc = "must be True or False" - elif type_ in (int, float): - desc = f"must be a single {type_.__name__} value" - elif inspect.isclass(type_) and issubclass(type_, enum.Enum): - desc= f"must be one of: {'; '.join(type_.__members__)}" - elif hasattr(type_, '__origin__') and type_.__origin__ is Literal: - desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}" + elif field_type in (int, float): + desc = f"must be a single {field_type.__name__} value" + elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): + desc = f"must be one of: {'; '.join(field_type.__members__)}" + elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal: + desc = f"must be one of: {'; '.join([str(x) for x in field_type.__args__])}" else: desc = "must be pareseable according to the following JSON schema: " - desc += json.dumps(prepare_schema(type_), ensure_ascii=False) + desc += json.dumps(prepare_schema(field_type), ensure_ascii=False) desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else "" return f"{{{field_name}}}{desc}" @@ -399,7 +418,7 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info) for field_name, field_info in fields.items() }, - assume_text=True + assume_text=True, ) parts.append(format_signature_fields_for_instructions(signature.input_fields))