Skip to content

Commit

Permalink
Adds chat anthropic provider, new models (#391)
Browse files Browse the repository at this point in the history
* Adds chat anthropic provider, new models

* Added docs for anthropic chat
  • Loading branch information
3coins committed Oct 5, 2023
1 parent 491490a commit 2047368
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Jupyter AI supports the following model providers:
|---------------------|----------------------|----------------------------|---------------------------------|
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `anthropic` |
| Bedrock | `amazon-bedrock` | N/A | `boto3` |
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
Expand Down Expand Up @@ -437,6 +438,7 @@ We currently support the following language model providers:

- `ai21`
- `anthropic`
- `anthropic-chat`
- `cohere`
- `huggingface_hub`
- `openai`
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AzureChatOpenAIProvider,
BaseProvider,
BedrockProvider,
ChatAnthropicProvider,
ChatOpenAINewProvider,
ChatOpenAIProvider,
CohereProvider,
Expand Down
22 changes: 20 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from IPython.display import HTML, JSON, Markdown, Math
from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
from langchain.chains import LLMChain
from langchain.schema import HumanMessage

from .parsers import (
CellArgs,
Expand Down Expand Up @@ -138,6 +139,12 @@ def __init__(self, shell):
"no longer supported. Instead, please use: "
"`from langchain.chat_models import ChatOpenAI`",
)
# suppress warning when using old Anthropic provider
warnings.filterwarnings(
"ignore",
message="This Anthropic LLM is deprecated. Please use "
"`from langchain.chat_models import ChatAnthropic` instead",
)

self.providers = get_lm_providers()

Expand Down Expand Up @@ -542,8 +549,19 @@ def run_ai_cell(self, args: CellArgs, prompt: str):

provider = Provider(**provider_params)

# generate output from model via provider
result = provider.generate([prompt])
# Apply a prompt template.
prompt = provider.get_prompt_template(args.format).format(prompt=prompt)

# interpolate user namespace into prompt
ip = get_ipython()
prompt = prompt.format_map(FormatDict(ip.user_ns))

if provider_id == "anthropic-chat":
result = provider.generate([[HumanMessage(content=prompt)]])
else:
# generate output from model via provider
result = provider.generate([prompt])

output = result.generations[0][0].text

# if openai-chat, append exchange to transcript
Expand Down
34 changes: 33 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union

from jsonpath_ng import parse
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain import PromptTemplate
from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI

from langchain.llms import (
AI21,
Anthropic,
Expand Down Expand Up @@ -183,8 +185,28 @@ class AnthropicProvider(BaseProvider, Anthropic):
"claude-v1.0",
"claude-v1.2",
"claude-2",
"claude-2.0",
"claude-instant-v1",
"claude-instant-v1.0",
"claude-instant-v1.2",
]
model_id_key = "model"
pypi_package_deps = ["anthropic"]
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")


class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
id = "anthropic-chat"
name = "ChatAnthropic"
models = [
"claude-v1",
"claude-v1.0",
"claude-v1.2",
"claude-2",
"claude-2.0",
"claude-instant-v1",
"claude-instant-v1.0",
"claude-instant-v1.2",
]
model_id_key = "model"
pypi_package_deps = ["anthropic"]
Expand Down Expand Up @@ -530,10 +552,20 @@ class BedrockProvider(BaseProvider, Bedrock):
"anthropic.claude-v2",
"ai21.j2-jumbo-instruct",
"ai21.j2-grande-instruct",
"ai21.j2-mid",
"ai21.j2-ultra",
]
model_id_key = "model_id"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
]

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)
3 changes: 2 additions & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ test = [

all = [
"ai21",
"anthropic~=0.2.10",
"anthropic~=0.3.0",
"cohere",
"gpt4all",
"huggingface_hub",
Expand All @@ -66,6 +66,7 @@ openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"

[project.entry-points."jupyter_ai.embeddings_model_providers"]
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
Expand Down

0 comments on commit 2047368

Please sign in to comment.