From 14965201722fa42256d261a70fe3445f979127fe Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 25 Sep 2023 13:28:48 -0700 Subject: [PATCH 1/6] Adds chat anthropic provider, new models (#391) * Adds chat anthropic provider, new models * Added docs for anthropic chat --- docs/source/users/index.md | 2 ++ .../jupyter_ai_magics/__init__.py | 1 + .../jupyter_ai_magics/magics.py | 22 ++++++++++-- .../jupyter_ai_magics/providers.py | 34 ++++++++++++++++++- packages/jupyter-ai-magics/pyproject.toml | 3 +- 5 files changed, 58 insertions(+), 4 deletions(-) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index f4b9de285..a527de0c5 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -39,6 +39,7 @@ Jupyter AI supports the following model providers: |---------------------|----------------------|----------------------------|---------------------------------| | AI21 | `ai21` | `AI21_API_KEY` | `ai21` | | Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` | +| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `anthropic` | | Bedrock | `amazon-bedrock` | N/A | `boto3` | | Cohere | `cohere` | `COHERE_API_KEY` | `cohere` | | Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` | @@ -437,6 +438,7 @@ We currently support the following language model providers: - `ai21` - `anthropic` +- `anthropic-chat` - `cohere` - `huggingface_hub` - `openai` diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index 60020823a..f419fdedd 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -16,6 +16,7 @@ AzureChatOpenAIProvider, BaseProvider, BedrockProvider, + ChatAnthropicProvider, ChatOpenAINewProvider, ChatOpenAIProvider, CohereProvider, diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 7745b6f2f..ec03f2e61 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -13,6 +13,7 @@ from IPython.display import HTML, JSON, Markdown, Math from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers from langchain.chains import LLMChain +from langchain.schema import HumanMessage from .parsers import ( CellArgs, @@ -138,6 +139,12 @@ def __init__(self, shell): "no longer supported. Instead, please use: " "`from langchain.chat_models import ChatOpenAI`", ) + # suppress warning when using old Anthropic provider + warnings.filterwarnings( + "ignore", + message="This Anthropic LLM is deprecated. Please use " + "`from langchain.chat_models import ChatAnthropic` instead", + ) self.providers = get_lm_providers() @@ -542,8 +549,19 @@ def run_ai_cell(self, args: CellArgs, prompt: str): provider = Provider(**provider_params) - # generate output from model via provider - result = provider.generate([prompt]) + # Apply a prompt template. 
+ prompt = provider.get_prompt_template(args.format).format(prompt=prompt) + + # interpolate user namespace into prompt + ip = get_ipython() + prompt = prompt.format_map(FormatDict(ip.user_ns)) + + if provider_id == "anthropic-chat": + result = provider.generate([[HumanMessage(content=prompt)]]) + else: + # generate output from model via provider + result = provider.generate([prompt]) + output = result.generations[0][0].text # if openai-chat, append exchange to transcript diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 153362669..4b5edb547 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -8,7 +8,9 @@ from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union from jsonpath_ng import parse -from langchain.chat_models import AzureChatOpenAI, ChatOpenAI +from langchain import PromptTemplate +from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI + from langchain.llms import ( AI21, Anthropic, @@ -183,8 +185,28 @@ class AnthropicProvider(BaseProvider, Anthropic): "claude-v1.0", "claude-v1.2", "claude-2", + "claude-2.0", + "claude-instant-v1", + "claude-instant-v1.0", + "claude-instant-v1.2", + ] + model_id_key = "model" + pypi_package_deps = ["anthropic"] + auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") + + +class ChatAnthropicProvider(BaseProvider, ChatAnthropic): + id = "anthropic-chat" + name = "ChatAnthropic" + models = [ + "claude-v1", + "claude-v1.0", + "claude-v1.2", + "claude-2", + "claude-2.0", "claude-instant-v1", "claude-instant-v1.0", + "claude-instant-v1.2", ] model_id_key = "model" pypi_package_deps = ["anthropic"] @@ -530,10 +552,20 @@ class BedrockProvider(BaseProvider, Bedrock): "anthropic.claude-v2", "ai21.j2-jumbo-instruct", "ai21.j2-grande-instruct", + "ai21.j2-mid", + "ai21.j2-ultra", ] model_id_key = "model_id" pypi_package_deps = ["boto3"] auth_strategy = AwsAuthStrategy() + fields = [ + TextField( + key="credentials_profile_name", + label="AWS profile (optional)", + format="text", + ), + TextField(key="region_name", label="Region name (optional)", format="text"), + ] async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 38d885e7a..2d059a27c 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -44,7 +44,7 @@ test = [ all = [ "ai21", - "anthropic~=0.2.10", + "anthropic~=0.3.0", "cohere", "gpt4all", "huggingface_hub", @@ -66,6 +66,7 @@ openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider" azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider" sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider" amazon-bedrock = "jupyter_ai_magics:BedrockProvider" +anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider" [project.entry-points."jupyter_ai.embeddings_model_providers"] cohere = "jupyter_ai_magics:CohereEmbeddingsProvider" From b0827bca4306a8003da7b767e9e1bd2355a473d0 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 25 Sep 2023 13:28:48 -0700 Subject: [PATCH 2/6] Upgraded LangChain, fixed prompts for Bedrock --- .../jupyter_ai_magics/__init__.py | 1 + .../jupyter_ai_magics/magics.py | 5 +- .../jupyter_ai_magics/providers.py | 80 +++++++++++++++++-- packages/jupyter-ai-magics/pyproject.toml | 3 +- 
.../jupyter_ai/chat_handlers/ask.py | 4 +- .../jupyter_ai/chat_handlers/default.py | 63 +++++++++++---- .../jupyter_ai/chat_handlers/learn.py | 2 +- packages/jupyter-ai/pyproject.toml | 2 +- 8 files changed, 133 insertions(+), 27 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index f419fdedd..f87992ae1 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -15,6 +15,7 @@ AnthropicProvider, AzureChatOpenAIProvider, BaseProvider, + BedrockChatProvider, BedrockProvider, ChatAnthropicProvider, ChatOpenAINewProvider, diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index ec03f2e61..28957b50a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -430,6 +430,9 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: return self.providers[provider_id] + def _is_chat_model(self, provider_id: str) -> bool: + return provider_id in ["anthropic-chat", "bedrock-chat"] + def display_output(self, output, display_format, md): # build output display DisplayClass = DISPLAYS_BY_FORMAT[display_format] @@ -556,7 +559,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str): ip = get_ipython() prompt = prompt.format_map(FormatDict(ip.user_ns)) - if provider_id == "anthropic-chat": + if self._is_chat_model(provider.id): result = provider.generate([[HumanMessage(content=prompt)]]) else: # generate output from model via provider diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 4b5edb547..36c0ea6d5 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -8,8 +8,13 @@ from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union from jsonpath_ng import parse -from langchain import PromptTemplate -from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI + +from langchain.chat_models import ( + AzureChatOpenAI, + BedrockChat, + ChatAnthropic, + ChatOpenAI, +) from langchain.llms import ( AI21, @@ -24,6 +29,8 @@ ) from langchain.llms.sagemaker_endpoint import LLMContentHandler from langchain.llms.utils import enforce_stop_tokens +from langchain.prompts import PromptTemplate +from langchain.schema import LLMResult from langchain.utils import get_from_dict_or_env from pydantic import BaseModel, Extra, root_validator @@ -154,6 +161,35 @@ async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]: _call_with_args = functools.partial(self._call, *args, **kwargs) return await loop.run_in_executor(executor, _call_with_args) + async def _generate_in_executor( + self, *args, **kwargs + ) -> Coroutine[Any, Any, LLMResult]: + """ + Calls self._call() asynchronously in a separate thread for providers + without an async implementation. Requires the event loop to be running. + """ + executor = ThreadPoolExecutor(max_workers=1) + loop = asyncio.get_running_loop() + _call_with_args = functools.partial(self._generate, *args, **kwargs) + return await loop.run_in_executor(executor, _call_with_args) + + def update_prompt_template(self, format: str, template: str): + """ + Changes the class-level prompt template for a given format. 
+ """ + self.prompt_templates[format] = PromptTemplate.from_template(template) + + def get_prompt_template(self, format) -> PromptTemplate: + """ + Produce a prompt template suitable for use with a particular model, to + produce output in a desired format. + """ + + if format in self.prompt_templates: + return self.prompt_templates[format] + else: + return self.prompt_templates["text"] # Default to plain format + class AI21Provider(BaseProvider, AI21): id = "ai21" @@ -546,14 +582,41 @@ class BedrockProvider(BaseProvider, Bedrock): id = "bedrock" name = "Amazon Bedrock" models = [ - "amazon.titan-tg1-large", + "amazon.titan-text-express-v1", "anthropic.claude-v1", + "anthropic.claude-v2", "anthropic.claude-instant-v1", + "ai21.j2-ultra-v1", + "ai21.j2-mid-v1", + "cohere.command-text-v14", + ] + model_id_key = "model_id" + pypi_package_deps = ["boto3"] + auth_strategy = AwsAuthStrategy() + fields = [ + TextField( + key="credentials_profile_name", + label="AWS profile (optional)", + format="text", + ), + TextField(key="region_name", label="Region name (optional)", format="text"), + ] + + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: + return await self._call_in_executor(*args, **kwargs) + + +class BedrockChatProvider(BaseProvider, BedrockChat): + id = "bedrock-chat" + name = "Amazon Bedrock Chat" + models = [ + "amazon.titan-text-express-v1", + "anthropic.claude-v1", "anthropic.claude-v2", - "ai21.j2-jumbo-instruct", - "ai21.j2-grande-instruct", - "ai21.j2-mid", - "ai21.j2-ultra", + "anthropic.claude-instant-v1", + "ai21.j2-ultra-v1", + "ai21.j2-mid-v1", + "cohere.command-text-v14", ] model_id_key = "model_id" pypi_package_deps = ["boto3"] @@ -569,3 +632,6 @@ class BedrockProvider(BaseProvider, Bedrock): async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) + + async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]: + return await self._generate_in_executor(*args, **kwargs) diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 2d059a27c..b119dc172 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "ipython", "pydantic~=1.0", "importlib_metadata>=5.2.0", - "langchain==0.0.277", + "langchain==0.0.306", "typing_extensions>=4.5.0", "click~=8.0", "jsonpath-ng>=1.5.3,<2", @@ -67,6 +67,7 @@ azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider" sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider" amazon-bedrock = "jupyter_ai_magics:BedrockProvider" anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider" +amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider" [project.entry-points."jupyter_ai.embeddings_model_providers"] cohere = "jupyter_ai_magics:CohereEmbeddingsProvider" diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 88ddd9c8f..cad14b0e5 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -11,7 +11,7 @@ class AskChatHandler(BaseChatHandler): """Processes messages prefixed with /ask. This actor will send the message as input to a RetrieverQA chain, that - follows the Retrieval and Generation (RAG) tehnique to + follows the Retrieval and Generation (RAG) technique to query the documents from the index, and sends this context to the LLM to generate the final reply. 
""" @@ -29,7 +29,7 @@ def create_llm_chain( self.llm = provider(**provider_params) self.chat_history = [] self.llm_chain = ConversationalRetrievalChain.from_llm( - self.llm, self._retriever + self.llm, self._retriever, verbose=True ) async def _process_message(self, message: HumanChatMessage): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index c11b63278..c468cc6f2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,8 +1,12 @@ -from typing import Dict, List, Type +from typing import Any, Dict, List, Type from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage -from jupyter_ai_magics.providers import BaseProvider -from langchain import ConversationChain +from jupyter_ai_magics.providers import ( + BaseProvider, + BedrockChatProvider, + BedrockProvider, +) +from langchain.chains import ConversationChain from langchain.memory import ConversationBufferWindowMemory from langchain.prompts import ( ChatPromptTemplate, @@ -10,7 +14,8 @@ MessagesPlaceholder, SystemMessagePromptTemplate, ) -from langchain.schema import AIMessage +from langchain.schema import AIMessage, ChatMessage +from langchain.schema.messages import BaseMessage from .base import BaseChatHandler @@ -26,6 +31,20 @@ """.strip() +class HistoryPlaceholderTemplate(MessagesPlaceholder): + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + values = super().format_messages(**kwargs) + corrected_values = [] + for v in values: + if isinstance(v, AIMessage): + corrected_values.append( + ChatMessage(role="Assistant", content=v.content) + ) + else: + corrected_values.append(v) + return corrected_values + + class DefaultChatHandler(BaseChatHandler): def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): super().__init__(*args, **kwargs) @@ -36,16 +55,32 @@ def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): llm = provider(**provider_params) - prompt_template = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( - provider_name=llm.name, local_model_id=llm.model_id - ), - MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), - AIMessage(content=""), - ] - ) + if provider == BedrockChatProvider or provider == BedrockProvider: + prompt_template = ChatPromptTemplate.from_messages( + [ + ChatMessage( + role="Instructions", + content=SYSTEM_PROMPT.format( + provider_name=llm.name, local_model_id=llm.model_id + ), + ), + HistoryPlaceholderTemplate(variable_name="history"), + HumanMessagePromptTemplate.from_template("{input}"), + ChatMessage(role="Assistant", content=""), + ] + ) + else: + prompt_template = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( + provider_name=llm.name, local_model_id=llm.model_id + ), + MessagesPlaceholder(variable_name="history"), + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content=""), + ] + ) + self.llm = llm self.llm_chain = ConversationChain( llm=llm, prompt=prompt_template, verbose=True, memory=self.memory diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 712444f3c..2d011e522 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -15,7 +15,6 
@@ IndexMetadata, ) from jupyter_core.paths import jupyter_data_dir -from langchain import FAISS from langchain.schema import BaseRetriever, Document from langchain.text_splitter import ( LatexTextSplitter, @@ -23,6 +22,7 @@ PythonCodeTextSplitter, RecursiveCharacterTextSplitter, ) +from langchain.vectorstores import FAISS from .base import BaseChatHandler diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index b3d8918a7..0b377dd51 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "openai~=0.26", "aiosqlite>=0.18", "importlib_metadata>=5.2.0", - "langchain==0.0.277", + "langchain==0.0.306", "tiktoken", # required for OpenAIEmbeddings "jupyter_ai_magics", "dask[distributed]", From efcd27f1b632b6fc2325c43219cd0e68243a1e79 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 2 Oct 2023 19:39:52 -0700 Subject: [PATCH 3/6] Updated docs for bedrock-chat --- docs/source/users/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index a527de0c5..5ef4e4ed6 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -41,6 +41,7 @@ Jupyter AI supports the following model providers: | Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` | | Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `anthropic` | | Bedrock | `amazon-bedrock` | N/A | `boto3` | +| Bedrock (chat) | `amazon-bedrock-chat`| N/A | `boto3` | | Cohere | `cohere` | `COHERE_API_KEY` | `cohere` | | Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` | | OpenAI | `openai` | `OPENAI_API_KEY` | `openai` | From 14f426e3ff5d82103ce4a26529450b1d1cc79966 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 4 Oct 2023 22:52:45 -0700 Subject: [PATCH 4/6] Added bedrock embeddings, refactored chat vs reg models --- .../jupyter_ai_magics/__init__.py | 1 + .../jupyter_ai_magics/embedding_providers.py | 20 ++++++- .../jupyter_ai_magics/magics.py | 5 +- .../jupyter_ai_magics/providers.py | 16 ++--- packages/jupyter-ai-magics/pyproject.toml | 3 +- .../jupyter_ai/chat_handlers/base.py | 1 + .../jupyter_ai/chat_handlers/default.py | 58 +++++++------------ packages/jupyter-ai/pyproject.toml | 2 +- 8 files changed, 51 insertions(+), 55 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index f87992ae1..ba756a452 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -2,6 +2,7 @@ # expose embedding model providers on the package root from .embedding_providers import ( + BedrockEmbeddingsProvider, CohereEmbeddingsProvider, HfHubEmbeddingsProvider, OpenAIEmbeddingsProvider, diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index bdfd7012c..5fe522beb 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -1,7 +1,13 @@ from typing import ClassVar, List, Type -from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy, Field +from jupyter_ai_magics.providers import ( + AuthStrategy, + AwsAuthStrategy, + EnvAuthStrategy, + Field, +) from langchain.embeddings import ( + BedrockEmbeddings, CohereEmbeddings, 
HuggingFaceHubEmbeddings, OpenAIEmbeddings, @@ -54,7 +60,8 @@ def __init__(self, *args, **kwargs): ) model_kwargs = {} - model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] + if self.__class__.model_id_key != "model_id": + model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] super().__init__(*args, **kwargs, **model_kwargs) @@ -88,3 +95,12 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings): pypi_package_deps = ["huggingface_hub", "ipywidgets"] auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") registry = True + + +class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings): + id = "bedrock" + name = "Bedrock" + models = ["amazon.titan-embed-text-v1"] + model_id_key = "model_id" + pypi_package_deps = ["boto3"] + auth_strategy = AwsAuthStrategy() diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 28957b50a..025d32dba 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -430,9 +430,6 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: return self.providers[provider_id] - def _is_chat_model(self, provider_id: str) -> bool: - return provider_id in ["anthropic-chat", "bedrock-chat"] - def display_output(self, output, display_format, md): # build output display DisplayClass = DISPLAYS_BY_FORMAT[display_format] @@ -559,7 +556,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str): ip = get_ipython() prompt = prompt.format_map(FormatDict(ip.user_ns)) - if self._is_chat_model(provider.id): + if provider.is_chat_provider(provider): result = provider.generate([[HumanMessage(content=prompt)]]) else: # generate output from model via provider diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 36c0ea6d5..df4e22e26 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -8,14 +8,13 @@ from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union from jsonpath_ng import parse - from langchain.chat_models import ( AzureChatOpenAI, BedrockChat, ChatAnthropic, ChatOpenAI, ) - +from langchain.chat_models.base import BaseChatModel from langchain.llms import ( AI21, Anthropic, @@ -165,7 +164,7 @@ async def _generate_in_executor( self, *args, **kwargs ) -> Coroutine[Any, Any, LLMResult]: """ - Calls self._call() asynchronously in a separate thread for providers + Calls self._generate() asynchronously in a separate thread for providers without an async implementation. Requires the event loop to be running. 
""" executor = ThreadPoolExecutor(max_workers=1) @@ -190,6 +189,10 @@ def get_prompt_template(self, format) -> PromptTemplate: else: return self.prompt_templates["text"] # Default to plain format + @property + def is_chat_provider(self): + return isinstance(self, BaseChatModel) + class AI21Provider(BaseProvider, AI21): id = "ai21" @@ -583,9 +586,6 @@ class BedrockProvider(BaseProvider, Bedrock): name = "Amazon Bedrock" models = [ "amazon.titan-text-express-v1", - "anthropic.claude-v1", - "anthropic.claude-v2", - "anthropic.claude-instant-v1", "ai21.j2-ultra-v1", "ai21.j2-mid-v1", "cohere.command-text-v14", @@ -610,13 +610,9 @@ class BedrockChatProvider(BaseProvider, BedrockChat): id = "bedrock-chat" name = "Amazon Bedrock Chat" models = [ - "amazon.titan-text-express-v1", "anthropic.claude-v1", "anthropic.claude-v2", "anthropic.claude-instant-v1", - "ai21.j2-ultra-v1", - "ai21.j2-mid-v1", - "cohere.command-text-v14", ] model_id_key = "model_id" pypi_package_deps = ["boto3"] diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index b119dc172..a310621c1 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "ipython", "pydantic~=1.0", "importlib_metadata>=5.2.0", - "langchain==0.0.306", + "langchain==0.0.308", "typing_extensions>=4.5.0", "click~=8.0", "jsonpath-ng>=1.5.3,<2", @@ -70,6 +70,7 @@ anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider" amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider" [project.entry-points."jupyter_ai.embeddings_model_providers"] +bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider" cohere = "jupyter_ai_magics:CohereEmbeddingsProvider" huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider" openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider" diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 87da6d214..6ad4e4ec8 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -9,6 +9,7 @@ from jupyter_ai.config_manager import ConfigManager, Logger from jupyter_ai.models import AgentChatMessage, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider +from langchain.chat_models.base import BaseChatModel if TYPE_CHECKING: from jupyter_ai.handlers import RootChatHandler diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index c468cc6f2..1c20fad9a 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,21 +1,17 @@ -from typing import Any, Dict, List, Type +from typing import Dict, List, Type from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage -from jupyter_ai_magics.providers import ( - BaseProvider, - BedrockChatProvider, - BedrockProvider, -) +from jupyter_ai_magics.providers import BaseProvider from langchain.chains import ConversationChain +from langchain.chat_models.base import BaseChatModel from langchain.memory import ConversationBufferWindowMemory from langchain.prompts import ( ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, + PromptTemplate, SystemMessagePromptTemplate, ) -from langchain.schema import AIMessage, ChatMessage -from langchain.schema.messages import BaseMessage from .base import BaseChatHandler @@ -30,19 +26,10 @@ The following is a friendly 
conversation between you and a human. """.strip() - -class HistoryPlaceholderTemplate(MessagesPlaceholder): - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - values = super().format_messages(**kwargs) - corrected_values = [] - for v in values: - if isinstance(v, AIMessage): - corrected_values.append( - ChatMessage(role="Assistant", content=v.content) - ) - else: - corrected_values.append(v) - return corrected_values +DEFAULT_TEMPLATE = """Current conversation: +{history} +Human: {input} +AI:""" class DefaultChatHandler(BaseChatHandler): @@ -55,21 +42,8 @@ def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): llm = provider(**provider_params) - if provider == BedrockChatProvider or provider == BedrockProvider: - prompt_template = ChatPromptTemplate.from_messages( - [ - ChatMessage( - role="Instructions", - content=SYSTEM_PROMPT.format( - provider_name=llm.name, local_model_id=llm.model_id - ), - ), - HistoryPlaceholderTemplate(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), - ChatMessage(role="Assistant", content=""), - ] - ) - else: + + if llm.is_chat_provider: prompt_template = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( @@ -77,9 +51,19 @@ def create_llm_chain( ), MessagesPlaceholder(variable_name="history"), HumanMessagePromptTemplate.from_template("{input}"), - AIMessage(content=""), ] ) + self.memory = ConversationBufferWindowMemory(return_messages=True, k=2) + else: + prompt_template = PromptTemplate( + input_variables=["history", "input"], + template=SYSTEM_PROMPT.format( + provider_name=llm.name, local_model_id=llm.model_id + ) + + "\n\n" + + DEFAULT_TEMPLATE, + ) + self.memory = ConversationBufferWindowMemory(k=2) self.llm = llm self.llm_chain = ConversationChain( diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 0b377dd51..fd8a06934 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "openai~=0.26", "aiosqlite>=0.18", "importlib_metadata>=5.2.0", - "langchain==0.0.306", + "langchain==0.0.308", "tiktoken", # required for OpenAIEmbeddings "jupyter_ai_magics", "dask[distributed]", From f1c50e3b176a5108b31c387ae93f25b700f7b6ae Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 4 Oct 2023 23:07:08 -0700 Subject: [PATCH 5/6] Fixed magics --- packages/jupyter-ai-magics/jupyter_ai_magics/magics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 025d32dba..6f1d29546 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -556,7 +556,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str): ip = get_ipython() prompt = prompt.format_map(FormatDict(ip.user_ns)) - if provider.is_chat_provider(provider): + if provider.is_chat_provider: result = provider.generate([[HumanMessage(content=prompt)]]) else: # generate output from model via provider From fb18c38190748928780f642ccaf19ad1ea777a7a Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Thu, 5 Oct 2023 07:20:26 -0700 Subject: [PATCH 6/6] Removed unused import --- packages/jupyter-ai/jupyter_ai/chat_handlers/default.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py 
index 1c20fad9a..c674a383b 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -3,7 +3,6 @@ from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.chains import ConversationChain -from langchain.chat_models.base import BaseChatModel from langchain.memory import ConversationBufferWindowMemory from langchain.prompts import ( ChatPromptTemplate,
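
Usage note: the patches above register two chat providers, `anthropic-chat`
(entry point `anthropic-chat`) and `bedrock-chat` (entry point
`amazon-bedrock-chat`), alongside the existing completion providers. A
minimal notebook sketch follows; it assumes `jupyter-ai-magics` and the
`anthropic` package are installed and that `ANTHROPIC_API_KEY` is exported.
The model IDs come from the `models` lists above, and the `-f`/`--format`
flag is shown only to illustrate the `args.format` plumbing in `run_ai_cell`:

    # Cell 1: load the magics extension.
    %load_ext jupyter_ai_magics

    # Cell 2 (a cell magic must open its own cell):
    %%ai anthropic-chat:claude-2 -f text
    Explain, in two sentences, how a chat model differs from a completion model.

The same pattern applies to Bedrock chat models, e.g.
`%%ai bedrock-chat:anthropic.claude-v2`, authenticated through the usual
boto3 credential chain (optionally via the `credentials_profile_name` and
`region_name` fields defined above).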
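Relatedly, `get_prompt_template` (added in patch 1 and shown above) resolves
a per-format prompt template with a plain-text fallback. A short sketch,
assuming `provider` is an instantiated `BaseProvider` subclass whose
`prompt_templates` dict has been populated (its initialization sits outside
this excerpt):

    # Falls back to the "text" template when no "code" template is registered.
    template = provider.get_prompt_template("code")
    final_prompt = template.format(prompt="sort a list of tuples by key")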
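Finally, patch 4's `is_chat_provider` property replaces the hard-coded
provider-ID list from patch 2 by checking inheritance from LangChain's
`BaseChatModel`, and patch 5 fixes the call site to read it as a property.
The resulting dispatch can be condensed into the following sketch
(`generate_text` is an illustrative name, not part of the patch; the
`generate` calls and result indexing are verbatim from `run_ai_cell`):

    from langchain.chat_models.base import BaseChatModel
    from langchain.schema import HumanMessage

    def generate_text(provider, prompt: str) -> str:
        # Chat models consume lists of message lists; completion models
        # consume raw prompt strings. Both paths return an LLMResult with
        # the same [batch][candidate] generations shape.
        if isinstance(provider, BaseChatModel):  # what is_chat_provider checks
            result = provider.generate([[HumanMessage(content=prompt)]])
        else:
            result = provider.generate([prompt])
        return result.generations[0][0].text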