Model Suggestion #89

Open
doncat99 opened this issue Nov 14, 2024 · 2 comments

Comments

@doncat99

doncat99 commented Nov 14, 2024

import os
import sys
import yaml
from loguru import logger

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_community.chat_models import ChatLiteLLMRouter
import litellm

litellm.set_verbose = True


# NOTE: models with streaming=True will send tokens as they are generated
# if the /stream endpoint is called with stream_tokens=True (the default)
class LiteLLMRouterFactory:
    @staticmethod
    def create_model_list():
        # Load the provider config (currently unused; kept for future model definitions).
        # A context manager closes the file handle, and safe_load is the safer default.
        with open("models_support.yaml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        model_list = []
        
        is_compatible = os.getenv("OPENAI_API_KEY") and os.getenv("OPENAI_API_BASE")

        if is_compatible:
            model_list.append({
                "model_name": "gpt-4o-mini",
                "provider": "OpenAI",
                "litellm_params": {
                    "model": "gpt-4o-mini",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })

        if os.getenv("GROQ_API_KEY"):
            model_list.append({
                "model_name": "llama-3.1-70b",
                "provider": "Groq",
                "litellm_params": {
                    "model": "groq/llama-3.1-70b",
                    "api_key": os.getenv("GROQ_API_KEY"),
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "llama-3.1-70b",
                "provider": "Groq",
                "litellm_params": {
                    "model": "openai/llama-3.1-70b",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                }
            })

        if os.getenv("GOOGLE_API_KEY"):
            model_list.append({
                "model_name": "gemini-1.5-flash",
                "provider": "Google",
                "litellm_params": {
                    "model": "gemini/gemini-1.5-flash",
                    "api_key": os.getenv("GOOGLE_API_KEY"),
                    "streaming": True,
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "gemini-1.5-flash",
                "provider": "Google",
                "litellm_params": {
                    "model": "openai/gemini-1.5-flash",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })

        if os.getenv("OLLAMA_BASE_URL"):
            model_list.append({
                "model_name": "llama3",
                "provider": "Ollama",
                "litellm_params": {
                    "model": "ollama_chat/llama3",
                    "api_base": os.getenv("OLLAMA_BASE_URL"),
                    "stream": True,
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "llama3",
                "provider": "Ollama",
                "litellm_params": {
                    "model": "openai/llama3",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })

        if os.getenv("ANTHROPIC_API_KEY"):
            model_list.append({
                "model_name": "claude-3-5-sonnet",
                "provider": "Anthropic",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20241022",
                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
                    "temperature": 0.5,
                    "streaming": True
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "claude-3-5-sonnet",
                "provider": "Anthropic",
                "litellm_params": {
                    "model": "openai/claude-3-5-sonnet-20241022",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "temperature": 0.5,
                    "streaming": True,
                }
            })

        if os.getenv("MISTRAL_API_KEY"):
            model_list.append({
                "model_name": "mistral-medium",
                "provider": "Mistral",
                "litellm_params": {
                    "model": "mistral/mistral-medium",
                    "api_key": os.getenv("MISTRAL_API_KEY"),
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "mistral-medium",
                "provider": "Mistral",
                "litellm_params": {
                    "model": "openai/mistral-medium",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })
            
        if os.getenv("USE_AWS_BEDROCK") == "true":
            model_list.append({
                "model_name": "bedrock-haiku",
                "provider":"AWS",
                "litellm_params": {
                    "model": "anthropic.claude-3-5-haiku-20241022-v1:0",
                    "model_id": "anthropic.claude-3-5-haiku-20241022-v1:0",
                    "temperature": 0.5,
                }
            })

        return model_list


class ModelManager:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.model_list = LiteLLMRouterFactory.create_model_list()
            cls._instance.models = cls._instance.initialize_models(cls._instance.model_list)
        return cls._instance
    
    def initialize_models(self, model_list) -> dict[str, BaseChatModel]:
        models: dict[str, BaseChatModel] = {}

        try:
            router = litellm.Router(model_list=model_list)
            model = ChatLiteLLMRouter(router=router)
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            model_list = []

        # Every entry shares one ChatLiteLLMRouter instance; the router
        # dispatches each request to the matching deployment by model name.
        for model_info in model_list:
            models[model_info["litellm_params"]["model"]] = model

        if not models:
            logger.error("No LLM available. Please set environment variables to enable at least one LLM.")
            if os.getenv("MODE") == "dev":
                logger.error("FastAPI initialization failed. Please use Ctrl + C to exit uvicorn.")
            sys.exit(1)

        return models

Hi, I just wrapped models.py with LiteLLM support. It removes a lot of the work of dealing with the various LLM providers. See if anyone needs it.
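For anyone who wants to try it, here's a minimal usage sketch. It assumes the snippet above is saved as models.py and that at least one provider's environment variables are set (gpt-4o-mini is just the example key here):

# Hypothetical usage, assuming the snippet above lives in models.py and
# OPENAI_API_KEY / OPENAI_API_BASE are set so "gpt-4o-mini" is registered.
from models import ModelManager

manager = ModelManager()                 # singleton; builds the router once
llm = manager.models["gpt-4o-mini"]      # a LangChain BaseChatModel

response = llm.invoke("Hello! Which model am I talking to?")
print(response.content)

Streaming goes through the usual LangChain interface (llm.astream(...)), which is what the /stream endpoint note in the code refers to.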

@doncat99
Author

There is a bug where LangChain simply selects the very first model from the model_list as the default model; see the github-langchain link for more detail. For now I'm patiently waiting for the bug fix.
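Until that's fixed, a possible workaround (my own sketch, not tested against every provider) is to build one single-entry Router per model, so the "first model" each ChatLiteLLMRouter sees is the one you actually want:

# Hypothetical workaround: one Router per model entry, so each
# ChatLiteLLMRouter's "first model" is the intended one.
def initialize_models_per_router(model_list) -> dict[str, BaseChatModel]:
    models: dict[str, BaseChatModel] = {}
    for model_info in model_list:
        try:
            router = litellm.Router(model_list=[model_info])
            models[model_info["litellm_params"]["model"]] = ChatLiteLLMRouter(router=router)
        except Exception as e:
            logger.error(f"Error loading {model_info['model_name']}: {e}")
    return models

This gives up LiteLLM's load balancing across deployments, but each model_name only appears once in the list here anyway.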

@JoshuaC215
Owner

Hey, I'm not sure if I want to adopt LiteLLM in this repo. I expect that in most real usage someone is connecting to just a few models, so this adds complexity and dependencies. What do you see as the advantages of this approach?
