From effc60920bca2a0d1d37beaff03d2bf18bb7d392 Mon Sep 17 00:00:00 2001
From: david qiu
Date: Thu, 18 Jan 2024 10:10:29 -0800
Subject: [PATCH] Update Cohere model IDs (#584)

* update Cohere model IDs

* get provider name from class attr instead of instance attr
---
 packages/jupyter-ai-magics/jupyter_ai_magics/providers.py | 3 ++-
 packages/jupyter-ai/jupyter_ai/chat_handlers/default.py   | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 4192ffbb3..7cd2c9c82 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -363,7 +363,8 @@ def allows_concurrency(self):
 class CohereProvider(BaseProvider, Cohere):
     id = "cohere"
     name = "Cohere"
-    models = ["medium", "xlarge"]
+    # Source: https://docs.cohere.com/reference/generate
+    models = ["command", "command-nightly", "command-light", "command-light-nightly"]
     model_id_key = "model"
     pypi_package_deps = ["cohere"]
     auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 3a76fba44..0db83afdd 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -53,7 +53,7 @@ def create_llm_chain(
             prompt_template = ChatPromptTemplate.from_messages(
                 [
                     SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
-                        provider_name=llm.name, local_model_id=llm.model_id
+                        provider_name=provider.name, local_model_id=llm.model_id
                     ),
                     MessagesPlaceholder(variable_name="history"),
                     HumanMessagePromptTemplate.from_template("{input}"),
@@ -64,7 +64,7 @@ def create_llm_chain(
             prompt_template = PromptTemplate(
                 input_variables=["history", "input"],
                 template=SYSTEM_PROMPT.format(
-                    provider_name=llm.name, local_model_id=llm.model_id
+                    provider_name=provider.name, local_model_id=llm.model_id
                 )
                 + "\n\n"
                 + DEFAULT_TEMPLATE,