Skip to content

Commit

Permalink
Improve ChatAdapter, introduce JsonAdapter, add default retries with …
Browse files Browse the repository at this point in the history
…the latter. (stanfordnlp#1700)

* Improve ChatAdapter's handling of typed values and Pydantic models

* Fixes for Literal

* Fixes for formatting complex-typed values

* Improve ChatAdapter, introduce JsonAdapter, add default retries with the latter.

* Minor fixes

* Update lock file

* Updates for json retries
  • Loading branch information
okhat authored Oct 26, 2024
1 parent 16ba98a commit 9f8c26d
Show file tree
Hide file tree
Showing 12 changed files with 1,998 additions and 1,421 deletions.
3 changes: 2 additions & 1 deletion dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,5 @@

# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program.
def inspect_history(*args, **kwargs):
return settings.lm.inspect_history(*args, **kwargs)
from dspy.clients.lm import GLOBAL_HISTORY, _inspect_history
return _inspect_history(GLOBAL_HISTORY, *args, **kwargs)
3 changes: 2 additions & 1 deletion dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base import Adapter
from .chat_adapter import ChatAdapter
from .chat_adapter import ChatAdapter
from .json_adapter import JsonAdapter
25 changes: 16 additions & 9 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,22 @@ def __init_subclass__(cls, **kwargs) -> None:
cls.parse = with_callbacks(cls.parse)

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
inputs_ = self.format(signature, demos, inputs)
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)

outputs = lm(**inputs, **lm_kwargs)
outputs = lm(**inputs_, **lm_kwargs)
values = []

for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
values.append(value)

return values
try:
for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
values.append(value)
return values

except Exception as e:
from .json_adapter import JsonAdapter
if _parse_values and not isinstance(self, JsonAdapter):
return JsonAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values)
raise e

16 changes: 15 additions & 1 deletion dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import textwrap

from pydantic import TypeAdapter
from collections.abc import Mapping
from pydantic.fields import FieldInfo
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

Expand Down Expand Up @@ -269,6 +270,19 @@ def enumerate_fields(fields):
return "\n".join(parts).strip()


def move_type_to_front(d):
# Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence.
if isinstance(d, Mapping):
return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != 'type', item[0]))}
elif isinstance(d, list):
return [move_type_to_front(item) for item in d]
return d

def prepare_schema(type_):
schema = pydantic.TypeAdapter(type_).json_schema()
schema = move_type_to_front(schema)
return schema

def prepare_instructions(signature: SignatureMeta):
parts = []
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
Expand All @@ -290,7 +304,7 @@ def field_metadata(field_name, field_info):
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
else:
desc = "must be pareseable according to the following JSON schema: "
desc += json.dumps(pydantic.TypeAdapter(type_).json_schema())
desc += json.dumps(prepare_schema(type_))

desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"
Expand Down
Loading

0 comments on commit 9f8c26d

Please sign in to comment.