From 4f51cdf6c5fa3ed4d86eaa408a02106f29cb3e0f Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Wed, 4 Oct 2023 15:34:36 -0700 Subject: [PATCH] add support for anthropic, palm, bedrock, ai21, cohere, replicate, togetherai --- chainfury/components/openai/__init__.py | 49 ++++++++++--------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/chainfury/components/openai/__init__.py b/chainfury/components/openai/__init__.py index 4cb4812..d43b378 100644 --- a/chainfury/components/openai/__init__.py +++ b/chainfury/components/openai/__init__.py @@ -1,11 +1,10 @@ import requests from pydantic import BaseModel from typing import Any, List, Union, Dict, Optional - +from litellm import completion, text_completion, AuthenticationError, from chainfury import Secret, model_registry, exponential_backoff, Model, UnAuthException from chainfury.components.const import Env - def openai_completion( model: str, prompt: Union[str, List[Union[str, List[str]]]], @@ -69,13 +68,7 @@ def openai_completion( raise Exception("OpenAI API key not found. Please set OPENAI_TOKEN environment variable or pass through function") def _fn(): - r = requests.post( - "https://api.openai.com/v1/completions", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {openai_api_key}", - }, - json={ + data = { "model": model, "prompt": prompt, "max_tokens": max_tokens, @@ -90,13 +83,14 @@ def _fn(): "best_of": best_of, "logit_bias": logit_bias, "user": user, - }, - ) - if r.status_code == 401: - raise UnAuthException(r.text) - if r.status_code != 200: - raise Exception(f"OpenAI API returned status code {r.status_code}: {r.text}") - return r.json() + } + try: + return text_completion(**data) + except Exception as e: + if isinstance(e, AuthenticationError): + raise UnAuthException(e.text) + else: + raise Exception(f"API returned status code {e.status_code}: {e.message}") return exponential_backoff(_fn, max_retries=retry_count, retry_delay=retry_delay) @@ -195,13 +189,7 @@ def openai_chat( messages = [x.dict(skip_defaults=True) for x in messages] # type: ignore def _fn(): - r = requests.post( - "https://api.openai.com/v1/chat/completions", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {openai_api_key}", - }, - json={ + data = { "model": model, "messages": messages, "max_tokens": max_tokens, @@ -213,13 +201,14 @@ def _fn(): "frequency_penalty": frequency_penalty, "logit_bias": logit_bias, "user": user, - }, - ) - if r.status_code == 401: - raise UnAuthException(r.text) - if r.status_code != 200: - raise Exception(f"OpenAI API returned status code {r.status_code}: {r.text}") - return r.json() + } + try: + return completion(**data) + except Exception as e: + if isinstance(e, AuthenticationError): + raise UnAuthException(e.text) + else: + raise Exception(f"API returned status code {e.status_code}: {e.message}") return exponential_backoff(_fn, max_retries=retry_count, retry_delay=retry_delay)