Skip to content

Commit

Permalink
Merge pull request stanfordnlp#1604 from chenmoneygithub/fix-litellm-client
Browse files Browse the repository at this point in the history

Remove legacy code from clients/lm.py and reformat
  • Loading branch information
okhat authored Oct 8, 2024
2 parents cc368f8 + 069808f commit 92d25a7
Showing 1 changed file with 41 additions and 35 deletions.
76 changes: 41 additions & 35 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,45 @@
import os
import uuid
import ujson
import functools
from pathlib import Path
import os
import uuid
from datetime import datetime
from pathlib import Path

try:
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
import litellm
litellm.telemetry = False
import litellm
import ujson
from litellm.caching import Cache

from litellm.caching import Cache
disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache')
litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
litellm.telemetry = False

except ImportError:
class LitellmPlaceholder:
def __getattr__(self, _): raise ImportError("The LiteLLM package is not installed. Run `pip install litellm`.")
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

litellm = LitellmPlaceholder()

class LM:
def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs):
self.model = model
self.model_type = model_type
self.cache = cache
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
self.history = []

if "o1-" in model:
assert max_tokens >= 5000 and temperature == 1.0, \
"OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
assert (
max_tokens >= 5000 and temperature == 1.0
), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"

def __call__(self, prompt=None, messages=None, **kwargs):
# Build the request.
cache = kwargs.pop("cache", self.cache)
messages = messages or [{"role": "user", "content": prompt}]
kwargs = {**self.kwargs, **kwargs}

# Make the request and handle LRU & disk caching.
if self.model_type == "chat": completion = cached_litellm_completion if cache else litellm_completion
else: completion = cached_litellm_text_completion if cache else litellm_text_completion
if self.model_type == "chat":
completion = cached_litellm_completion if cache else litellm_completion
else:
completion = cached_litellm_text_completion if cache else litellm_text_completion

response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
Expand All @@ -63,8 +57,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
model_type=self.model_type,
)
self.history.append(entry)

return outputs

def inspect_history(self, n: int = 1):
_inspect_history(self, n)

Expand All @@ -73,14 +68,17 @@ def inspect_history(self, n: int = 1):
def cached_litellm_completion(request):
    """Run a chat completion with both cache-control flags switched off.

    Delegates to ``litellm_completion`` with the same ``request`` and a
    ``cache`` dict whose ``"no-cache"`` and ``"no-store"`` entries are False.

    Args:
        request: The serialized request, forwarded unchanged.

    Returns:
        Whatever ``litellm_completion`` returns.
    """
    caching_policy = {"no-cache": False, "no-store": False}
    return litellm_completion(request, cache=caching_policy)


def litellm_completion(request, cache=None):
    """Run a LiteLLM chat completion for a JSON-serialized request.

    Args:
        request: JSON string; the decoded object is expanded into keyword
            arguments for ``litellm.completion``.
        cache: Cache-control dict passed through as ``cache=``. When omitted
            (``None``), ``{"no-cache": True, "no-store": True}`` is used --
            presumably asking LiteLLM to neither read nor write its cache;
            confirm against the LiteLLM caching documentation.

    Returns:
        Whatever ``litellm.completion`` returns.
    """
    # Build the default on every call. A dict literal in the signature is one
    # object shared by all calls, so any mutation by the callee would silently
    # change the default for every later request.
    if cache is None:
        cache = {"no-cache": True, "no-store": True}

    kwargs = ujson.loads(request)
    return litellm.completion(cache=cache, **kwargs)


@functools.lru_cache(maxsize=None)
def cached_litellm_text_completion(request):
    """Run a text completion with both cache-control flags switched off.

    Results are memoized in-process by ``functools.lru_cache`` with no size
    bound, keyed on ``request`` (which must therefore be hashable). On a miss
    the call is delegated to ``litellm_text_completion`` with a ``cache`` dict
    whose ``"no-cache"`` and ``"no-store"`` entries are False.

    Args:
        request: The serialized request, forwarded unchanged.

    Returns:
        Whatever ``litellm_text_completion`` returns.
    """
    caching_policy = {"no-cache": False, "no-store": False}
    return litellm_text_completion(request, cache=caching_policy)


def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)

Expand All @@ -93,32 +91,40 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

# Build the prompt from the messages.
prompt = '\n\n'.join([x['content'] for x in kwargs.pop("messages")] + ['BEGIN RESPONSE:'])
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

return litellm.text_completion(cache=cache, model=f'text-completion-openai/{model}', api_key=api_key,
api_base=api_base, prompt=prompt, **kwargs)
return litellm.text_completion(
cache=cache,
model=f"text-completion-openai/{model}",
api_key=api_key,
api_base=api_base,
prompt=prompt,
**kwargs,
)


def _green(text: str, end: str = "\n"):
    """Return ``text`` wrapped in ANSI green, followed by ``end``.

    ``text`` is converted with ``str()`` and stripped of leading whitespace
    before being wrapped in the ``\\x1b[32m`` ... ``\\x1b[0m`` escape pair.
    """
    body = str(text).lstrip()
    return "".join(("\x1b[32m", body, "\x1b[0m", end))


def _red(text: str, end: str = "\n"):
    """Return ``text`` wrapped in ANSI red, followed by ``end``.

    ``text`` is converted with ``str()`` and kept verbatim (leading
    whitespace included) between the ``\\x1b[31m`` and ``\\x1b[0m`` escapes.
    """
    body = str(text)
    return "".join(("\x1b[31m", body, "\x1b[0m", end))


def _inspect_history(lm, n: int = 1):
"""Prints the last n prompts and their completions."""

for item in lm.history[-n:]:
messages = item["messages"] or [{"role": "user", "content": item['prompt']}]
messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
outputs = item["outputs"]
timestamp = item.get("timestamp", "Unknown time")

print("\n\n\n")
print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")

for msg in messages:
print(_red(f"{msg['role'].capitalize()} message:"))
print(msg['content'].strip())
print(msg["content"].strip())
print("\n")

print(_red("Response:"))
Expand All @@ -127,5 +133,5 @@ def _inspect_history(lm, n: int = 1):
if len(outputs) > 1:
choices_text = f" \t (and {len(outputs)-1} other completions)"
print(_red(choices_text, end=""))
print("\n\n\n")

print("\n\n\n")

0 comments on commit 92d25a7

Please sign in to comment.