
Commit f8a0ab2
Added bedrock embeddings, refactored chat vs reg models

3coins committed Oct 5, 2023
1 parent 7ab465b commit f8a0ab2

Showing 8 changed files with 51 additions and 55 deletions.
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -2,6 +2,7 @@

 # expose embedding model providers on the package root
 from .embedding_providers import (
+    BedrockEmbeddingsProvider,
     CohereEmbeddingsProvider,
     HfHubEmbeddingsProvider,
     OpenAIEmbeddingsProvider,
20 changes: 18 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
@@ -1,7 +1,13 @@
 from typing import ClassVar, List, Type
 
-from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy, Field
+from jupyter_ai_magics.providers import (
+    AuthStrategy,
+    AwsAuthStrategy,
+    EnvAuthStrategy,
+    Field,
+)
 from langchain.embeddings import (
+    BedrockEmbeddings,
     CohereEmbeddings,
     HuggingFaceHubEmbeddings,
     OpenAIEmbeddings,
@@ -54,7 +60,8 @@ def __init__(self, *args, **kwargs):
         )
 
         model_kwargs = {}
-        model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]
+        if self.__class__.model_id_key != "model_id":
+            model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]
 
         super().__init__(*args, **kwargs, **model_kwargs)
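Why the new guard is needed: for providers whose model_id_key is literally "model_id" (such as the Bedrock embeddings provider added below), the old code forwarded the same value twice, once via **kwargs and once via **model_kwargs, which Python rejects as a duplicate keyword argument. A minimal standalone sketch of the failure mode (variable names here are illustrative, not the provider code):

kwargs = {"model_id": "amazon.titan-embed-text-v1"}
model_kwargs = {"model_id": kwargs["model_id"]}  # what the old code built

try:
    # Same key expanded twice, as in super().__init__(**kwargs, **model_kwargs)
    merged = dict(**kwargs, **model_kwargs)
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'model_id'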

@@ -88,3 +95,12 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings):
     pypi_package_deps = ["huggingface_hub", "ipywidgets"]
     auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
     registry = True
+
+
+class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
+    id = "bedrock"
+    name = "Bedrock"
+    models = ["amazon.titan-embed-text-v1"]
+    model_id_key = "model_id"
+    pypi_package_deps = ["boto3"]
+    auth_strategy = AwsAuthStrategy()
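For context, the new provider class delegates to langchain's BedrockEmbeddings. A rough usage sketch, assuming boto3 is installed and AWS credentials plus a region are already configured (the call goes to the live Bedrock API):

from langchain.embeddings import BedrockEmbeddings

embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
vector = embeddings.embed_query("Hello, Bedrock!")
print(len(vector))  # dimensionality of the returned embedding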
5 changes: 1 addition & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -417,9 +417,6 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:

         return self.providers[provider_id]
 
-    def _is_chat_model(self, provider_id: str) -> bool:
-        return provider_id in ["anthropic-chat", "bedrock-chat"]
-
     def display_output(self, output, display_format, md):
         # build output display
         DisplayClass = DISPLAYS_BY_FORMAT[display_format]
@@ -539,7 +536,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
         ip = get_ipython()
         prompt = prompt.format_map(FormatDict(ip.user_ns))
 
-        if self._is_chat_model(provider.id):
+        if provider.is_chat_provider:
             result = provider.generate([[HumanMessage(content=prompt)]])
         else:
             # generate output from model via provider
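For reference, the chat branch above drives the provider through langchain's chat interface, where generate takes a batch of message lists. A standalone sketch with a fake chat model standing in for the real provider (FakeListChatModel is langchain's built-in test double; its availability in this exact langchain version is an assumption):

from langchain.chat_models.fake import FakeListChatModel
from langchain.schema import HumanMessage

model = FakeListChatModel(responses=["4"])  # stand-in for the configured provider
result = model.generate([[HumanMessage(content="What is 2+2?")]])
print(result.generations[0][0].text)  # "4"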
16 changes: 6 additions & 10 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -8,14 +8,13 @@
 from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
 
 from jsonpath_ng import parse
-
 from langchain.chat_models import (
     AzureChatOpenAI,
     BedrockChat,
     ChatAnthropic,
     ChatOpenAI,
 )
-
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms import (
     AI21,
     Anthropic,
@@ -199,7 +198,7 @@ async def _generate_in_executor(
         self, *args, **kwargs
     ) -> Coroutine[Any, Any, LLMResult]:
         """
-        Calls self._call() asynchronously in a separate thread for providers
+        Calls self._generate() asynchronously in a separate thread for providers
         without an async implementation. Requires the event loop to be running.
         """
         executor = ThreadPoolExecutor(max_workers=1)
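The corrected docstring reflects what the helper actually wraps: the synchronous _generate, pushed onto a worker thread so the event loop stays free. A generic sketch of the same pattern (stand-in function, not the provider code):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def slow_generate(prompt: str) -> str:  # stand-in for a blocking _generate
    return prompt.upper()

async def generate_in_executor(prompt: str) -> str:
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=1) as executor:
        # run_in_executor runs the blocking call on the worker thread
        return await loop.run_in_executor(executor, partial(slow_generate, prompt))

print(asyncio.run(generate_in_executor("hello")))  # HELLO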
@@ -224,6 +223,10 @@ def get_prompt_template(self, format) -> PromptTemplate:
         else:
             return self.prompt_templates["text"]  # Default to plain format
 
+    @property
+    def is_chat_provider(self):
+        return isinstance(self, BaseChatModel)
+
 
 class AI21Provider(BaseProvider, AI21):
     id = "ai21"
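The new property replaces hard-coded provider-ID lists (see the _is_chat_model removal in magics.py above) with a capability check based on inheritance. A standalone illustration using stand-in classes, not the real providers:

class BaseChatModel:  # stands in for langchain.chat_models.base.BaseChatModel
    pass

class BaseProvider:
    @property
    def is_chat_provider(self):
        return isinstance(self, BaseChatModel)

class ChatProvider(BaseProvider, BaseChatModel):  # chat-capable provider
    pass

class CompletionProvider(BaseProvider):  # plain completion provider
    pass

print(ChatProvider().is_chat_provider)        # True
print(CompletionProvider().is_chat_provider)  # False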
@@ -617,9 +620,6 @@ class BedrockProvider(BaseProvider, Bedrock):
     name = "Amazon Bedrock"
     models = [
         "amazon.titan-text-express-v1",
-        "anthropic.claude-v1",
-        "anthropic.claude-v2",
-        "anthropic.claude-instant-v1",
         "ai21.j2-ultra-v1",
         "ai21.j2-mid-v1",
         "cohere.command-text-v14",
@@ -644,13 +644,9 @@ class BedrockChatProvider(BaseProvider, BedrockChat):
     id = "bedrock-chat"
     name = "Amazon Bedrock Chat"
     models = [
-        "amazon.titan-text-express-v1",
         "anthropic.claude-v1",
         "anthropic.claude-v2",
         "anthropic.claude-instant-v1",
-        "ai21.j2-ultra-v1",
-        "ai21.j2-mid-v1",
-        "cohere.command-text-v14",
     ]
     model_id_key = "model_id"
     pypi_package_deps = ["boto3"]
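The split mirrors the two langchain entry points: the Claude models speak Bedrock's chat-style interface and move to BedrockChat, while Titan, AI21, and Cohere stay on the completion-style Bedrock LLM. A hedged usage sketch, assuming boto3 and configured AWS credentials (both calls hit the live API):

from langchain.chat_models import BedrockChat
from langchain.llms import Bedrock
from langchain.schema import HumanMessage

chat_llm = BedrockChat(model_id="anthropic.claude-v2")
completion_llm = Bedrock(model_id="amazon.titan-text-express-v1")

print(chat_llm([HumanMessage(content="Hello!")]).content)  # chat-style call
print(completion_llm("Hello!"))  # plain-completion call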
3 changes: 2 additions & 1 deletion packages/jupyter-ai-magics/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
"ipython",
"pydantic~=1.0",
"importlib_metadata>=5.2.0",
"langchain==0.0.306",
"langchain==0.0.308",
"typing_extensions>=4.5.0",
"click~=8.0",
"jsonpath-ng>=1.5.3,<2",
@@ -70,6 +70,7 @@ anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
 amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider"
 
 [project.entry-points."jupyter_ai.embeddings_model_providers"]
+bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
 cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
 huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
 openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"
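Registering the class under the jupyter_ai.embeddings_model_providers entry-point group is what makes it discoverable at runtime. An illustrative discovery loop (the group name comes from the diff; jupyter-ai's actual loader may differ):

from importlib_metadata import entry_points

for ep in entry_points(group="jupyter_ai.embeddings_model_providers"):
    provider_cls = ep.load()  # e.g. jupyter_ai_magics:BedrockEmbeddingsProvider
    print(ep.name, "->", provider_cls.__name__)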
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -9,6 +9,7 @@
 from jupyter_ai.config_manager import ConfigManager, Logger
 from jupyter_ai.models import AgentChatMessage, HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
+from langchain.chat_models.base import BaseChatModel
 
 if TYPE_CHECKING:
     from jupyter_ai.handlers import RootChatHandler
58 changes: 21 additions & 37 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,21 +1,17 @@
-from typing import Any, Dict, List, Type
+from typing import Dict, List, Type
 
 from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
-from jupyter_ai_magics.providers import (
-    BaseProvider,
-    BedrockChatProvider,
-    BedrockProvider,
-)
+from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import ConversationChain
+from langchain.chat_models.base import BaseChatModel
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     MessagesPlaceholder,
     PromptTemplate,
     SystemMessagePromptTemplate,
 )
-from langchain.schema import AIMessage, ChatMessage
-from langchain.schema.messages import BaseMessage
+from langchain.schema import AIMessage
 
 from .base import BaseChatHandler
@@ -30,19 +26,10 @@
 The following is a friendly conversation between you and a human.
 """.strip()
 
-
-class HistoryPlaceholderTemplate(MessagesPlaceholder):
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
-        values = super().format_messages(**kwargs)
-        corrected_values = []
-        for v in values:
-            if isinstance(v, AIMessage):
-                corrected_values.append(
-                    ChatMessage(role="Assistant", content=v.content)
-                )
-            else:
-                corrected_values.append(v)
-        return corrected_values
+DEFAULT_TEMPLATE = """Current conversation:
+{history}
+Human: {input}
+AI:"""
 
 
 class DefaultChatHandler(BaseChatHandler):
@@ -55,31 +42,28 @@ def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
         llm = provider(**provider_params)
-        if provider == BedrockChatProvider or provider == BedrockProvider:
-            prompt_template = ChatPromptTemplate.from_messages(
-                [
-                    ChatMessage(
-                        role="Instructions",
-                        content=SYSTEM_PROMPT.format(
-                            provider_name=llm.name, local_model_id=llm.model_id
-                        ),
-                    ),
-                    HistoryPlaceholderTemplate(variable_name="history"),
-                    HumanMessagePromptTemplate.from_template("{input}"),
-                    ChatMessage(role="Assistant", content=""),
-                ]
-            )
-        else:
+
+        if llm.is_chat_provider:
             prompt_template = ChatPromptTemplate.from_messages(
                 [
                     SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
                         provider_name=llm.name, local_model_id=llm.model_id
                     ),
                     MessagesPlaceholder(variable_name="history"),
                     HumanMessagePromptTemplate.from_template("{input}"),
                     AIMessage(content=""),
                 ]
             )
+            self.memory = ConversationBufferWindowMemory(return_messages=True, k=2)
+        else:
+            prompt_template = PromptTemplate(
+                input_variables=["history", "input"],
+                template=SYSTEM_PROMPT.format(
+                    provider_name=llm.name, local_model_id=llm.model_id
+                )
+                + "\n\n"
+                + DEFAULT_TEMPLATE,
+            )
+            self.memory = ConversationBufferWindowMemory(k=2)
 
         self.llm = llm
         self.llm_chain = ConversationChain(
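To see what the new non-chat branch produces, here is a condensed sketch of the PromptTemplate path with a shortened stand-in for SYSTEM_PROMPT (the real constant is longer):

from langchain.prompts import PromptTemplate

system = "You are Jupyternaut, a conversational assistant."  # stand-in for SYSTEM_PROMPT
DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
AI:"""

prompt = PromptTemplate(
    input_variables=["history", "input"],
    template=system + "\n\n" + DEFAULT_TEMPLATE,
)
print(prompt.format(history="Human: hi\nAI: hello", input="What is 2+2?"))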
2 changes: 1 addition & 1 deletion packages/jupyter-ai/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
"openai~=0.26",
"aiosqlite>=0.18",
"importlib_metadata>=5.2.0",
"langchain==0.0.306",
"langchain==0.0.308",
"tiktoken", # required for OpenAIEmbeddings
"jupyter_ai_magics",
"dask[distributed]",