From abedb08121d3db7b258d8ebd6a933281ba8a1e28 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 25 Sep 2023 13:28:48 -0700 Subject: [PATCH] Adds chat anthropic provider, new models (#391) * Adds chat anthropic provider, new models * Added docs for anthropic chat --- docs/source/users/index.md | 2 ++ .../jupyter_ai_magics/__init__.py | 1 + .../jupyter_ai_magics/magics.py | 15 +++++++-- .../jupyter_ai_magics/providers.py | 32 ++++++++++++++++++- packages/jupyter-ai-magics/pyproject.toml | 3 +- 5 files changed, 49 insertions(+), 4 deletions(-) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 83f91781d..a4c1bcda9 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -116,6 +116,7 @@ Jupyter AI supports the following model providers: |---------------------|----------------------|----------------------------|---------------------------------| | AI21 | `ai21` | `AI21_API_KEY` | `ai21` | | Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` | +| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `anthropic` | | Bedrock | `amazon-bedrock` | N/A | `boto3` | | Cohere | `cohere` | `COHERE_API_KEY` | `cohere` | | Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` | @@ -464,6 +465,7 @@ We currently support the following language model providers: - `ai21` - `anthropic` +- `anthropic-chat` - `cohere` - `huggingface_hub` - `openai` diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index 60020823a..f419fdedd 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -16,6 +16,7 @@ AzureChatOpenAIProvider, BaseProvider, BedrockProvider, + ChatAnthropicProvider, ChatOpenAINewProvider, ChatOpenAIProvider, CohereProvider, diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 47ee9558f..667010dcc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -13,6 +13,7 @@ from IPython.display import HTML, JSON, Markdown, Math from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers from langchain.chains import LLMChain +from langchain.schema import HumanMessage from .parsers import ( CellArgs, @@ -125,6 +126,12 @@ def __init__(self, shell): "no longer supported. Instead, please use: " "`from langchain.chat_models import ChatOpenAI`", ) + # suppress warning when using old Anthropic provider + warnings.filterwarnings( + "ignore", + message="This Anthropic LLM is deprecated. Please use " + "`from langchain.chat_models import ChatAnthropic` instead", + ) self.providers = get_lm_providers() @@ -529,8 +536,12 @@ def run_ai_cell(self, args: CellArgs, prompt: str): ip = get_ipython() prompt = prompt.format_map(FormatDict(ip.user_ns)) - # generate output from model via provider - result = provider.generate([prompt]) + if provider_id == "anthropic-chat": + result = provider.generate([[HumanMessage(content=prompt)]]) + else: + # generate output from model via provider + result = provider.generate([prompt]) + output = result.generations[0][0].text # if openai-chat, append exchange to transcript diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 149cf9fd8..041b870d5 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -9,7 +9,7 @@ from jsonpath_ng import parse from langchain import PromptTemplate -from langchain.chat_models import AzureChatOpenAI, ChatOpenAI +from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI from langchain.llms import ( AI21, Anthropic, @@ -235,8 +235,28 @@ class AnthropicProvider(BaseProvider, Anthropic): "claude-v1.0", "claude-v1.2", "claude-2", + "claude-2.0", "claude-instant-v1", "claude-instant-v1.0", + "claude-instant-v1.2", + ] + model_id_key = "model" + pypi_package_deps = ["anthropic"] + auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") + + +class ChatAnthropicProvider(BaseProvider, ChatAnthropic): + id = "anthropic-chat" + name = "ChatAnthropic" + models = [ + "claude-v1", + "claude-v1.0", + "claude-v1.2", + "claude-2", + "claude-2.0", + "claude-instant-v1", + "claude-instant-v1.0", + "claude-instant-v1.2", ] model_id_key = "model" pypi_package_deps = ["anthropic"] @@ -582,10 +602,20 @@ class BedrockProvider(BaseProvider, Bedrock): "anthropic.claude-v2", "ai21.j2-jumbo-instruct", "ai21.j2-grande-instruct", + "ai21.j2-mid", + "ai21.j2-ultra", ] model_id_key = "model_id" pypi_package_deps = ["boto3"] auth_strategy = AwsAuthStrategy() + fields = [ + TextField( + key="credentials_profile_name", + label="AWS profile (optional)", + format="text", + ), + TextField(key="region_name", label="Region name (optional)", format="text"), + ] async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 38d885e7a..2d059a27c 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -44,7 +44,7 @@ test = [ all = [ "ai21", - "anthropic~=0.2.10", + "anthropic~=0.3.0", "cohere", "gpt4all", "huggingface_hub", @@ -66,6 +66,7 @@ openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider" azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider" sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider" amazon-bedrock = "jupyter_ai_magics:BedrockProvider" +anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider" [project.entry-points."jupyter_ai.embeddings_model_providers"] cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"