diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py
index c33d183..b8e24e5 100644
--- a/api/src/stampy_chat/chat.py
+++ b/api/src/stampy_chat/chat.py
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Callable, Dict, List
 
 from langchain.chains import LLMChain, OpenAIModerationChain
@@ -14,13 +15,19 @@
 from langchain.schema import AIMessage, BaseMessage, HumanMessage, PromptValue, SystemMessage
 
 from stampy_chat.env import OPENAI_API_KEY, ANTHROPIC_API_KEY, LANGCHAIN_API_KEY, LANGCHAIN_TRACING_V2, SUMMARY_MODEL
-from stampy_chat.settings import Settings, MODELS, OPENAI, ANTRHROPIC
+from stampy_chat.settings import Settings, MODELS, OPENAI, ANTHROPIC
 from stampy_chat.callbacks import StampyCallbackHandler, BroadcastCallbackHandler, LoggerCallbackHandler
 from stampy_chat.followups import StampyChain
 from stampy_chat.citations import make_example_selector
 
 from langsmith import Client
 
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 if LANGCHAIN_TRACING_V2 == "true":
     if not LANGCHAIN_API_KEY:
         raise Exception("Langsmith tracing is enabled but no api key was provided. Please set LANGCHAIN_API_KEY in the .env file.")
@@ -192,7 +199,7 @@ def get_model(**kwargs):
     model = MODELS.get(kwargs.get('model'))
     if not model:
         raise ValueError("No model provided")
-    if model.publisher == ANTRHROPIC:
+    if model.publisher == ANTHROPIC:
         return ChatAnthropicWrapper(anthropic_api_key=ANTHROPIC_API_KEY, **kwargs)
     if model.publisher == OPENAI:
         return ChatOpenAI(openai_api_key=OPENAI_API_KEY, **kwargs)
@@ -245,7 +252,7 @@ def make_history_summary(settings):
 def make_prompt(settings, chat_model, callbacks):
     """Create a proper prompt object will all the nessesery steps."""
     # 1. Create the context prompt from items fetched from pinecone
-    context_template = "[{{reference}}] {{title}} {{authors | join(', ')}} - {{date_published}} {{text}}"
+    context_template = "\n\n[{{reference}}] {{title}} {{authors | join(', ')}} - {{date_published}} {{text}}\n\n"
     context_prompt = MessageBufferPromptTemplate(
         example_selector=make_example_selector(k=settings.topKBlocks, callbacks=callbacks),
         example_prompt=ChatPromptTemplate.from_template(context_template, template_format="jinja2"),
@@ -261,7 +268,10 @@ def make_prompt(settings, chat_model, callbacks):
     query_prompt = ChatPromptTemplate.from_messages(
         [
             HumanMessage(content=settings.question_prompt),
-            HumanMessagePromptTemplate.from_template(template='Q: {history_summary}: {query}', role='user'),
+            HumanMessagePromptTemplate.from_template(
+                template='{history_summary}{delimiter}{query}',
+                partial_variables={"delimiter": lambda **kwargs: ": " if kwargs.get("history_summary") else ""}
+            ),
         ]
     )
 
@@ -334,16 +344,37 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
         model=settings.completions
     )
 
-    chain = make_history_summary(settings) | LLMChain(
+    history_summary_chain = make_history_summary(settings)
+
+    if history:
+        history_summary_result = history_summary_chain.invoke({"query": query, 'history': history})
+        history_summary = history_summary_result.get('history_summary', '')
+    else:
+        history_summary = ''
+
+    delimiter = ": " if history_summary else ""
+
+    llm_chain = LLMChain(
         llm=chat_model,
         verbose=False,
         prompt=make_prompt(settings, chat_model, callbacks),
         memory=make_memory(settings, history, callbacks)
     )
+
+    chain = history_summary_chain | llm_chain
     if followups:
         chain = chain | StampyChain(callbacks=callbacks)
-    result = chain.invoke({"query": query, 'history': history}, {'callbacks': []})
+
+    chain_input = {
+        "query": query,
+        'history': history,
+        'history_summary': history_summary,
+        'delimiter': delimiter,
+    }
+
+    result = chain.invoke(chain_input)
+
     if callback:
         callback({'state': 'done'})
-        callback(None)  # make sure the callback handler know that things have ended
+        callback(None)
     return result
diff --git a/api/src/stampy_chat/env.py b/api/src/stampy_chat/env.py
index 49ed832..ee5e72e 100644
--- a/api/src/stampy_chat/env.py
+++ b/api/src/stampy_chat/env.py
@@ -20,8 +20,8 @@
 
 ### Models ###
 EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-ada-002")
-SUMMARY_MODEL = os.environ.get("SUMMARY_MODEL", "claude-3-5-sonnet-20240620")
-COMPLETIONS_MODEL = os.environ.get("COMPLETIONS_MODEL", "claude-3-5-sonnet-20240620")
+SUMMARY_MODEL = os.environ.get("SUMMARY_MODEL", "claude-3-5-sonnet-latest")
+COMPLETIONS_MODEL = os.environ.get("COMPLETIONS_MODEL", "claude-3-5-sonnet-latest")
 
 ### Pinecone ###
 PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
diff --git a/api/src/stampy_chat/settings.py b/api/src/stampy_chat/settings.py
index ffe1d58..195a014 100644
--- a/api/src/stampy_chat/settings.py
+++ b/api/src/stampy_chat/settings.py
@@ -59,20 +59,23 @@
     'modes': PROMPT_MODES,
 }
 
 OPENAI = 'openai'
-ANTRHROPIC = 'anthropic'
+ANTHROPIC = 'anthropic'
 
 MODELS = {
     'gpt-3.5-turbo': Model(4097, 10, 4096, OPENAI),
     'gpt-3.5-turbo-16k': Model(16385, 30, 4096, OPENAI),
     'gpt-4': Model(8192, 20, 4096, OPENAI),
     "gpt-4-turbo-preview": Model(128000, 50, 4096, OPENAI),
     "gpt-4o": Model(128000, 50, 4096, OPENAI),
-    "claude-3-opus-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-3-5-sonnet-20240620": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-3-sonnet-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-3-haiku-20240307": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-2.1": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-2.0": Model(100_000, 50, 4096, ANTRHROPIC),
-    "claude-instant-1.2": Model(100_000, 50, 4096, ANTRHROPIC),
+    "gpt-4o-mini": Model(128000, 50, 4096, OPENAI),
+    "claude-3-opus-20240229": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-5-sonnet-20240620": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-5-sonnet-20241022": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-5-sonnet-latest": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-sonnet-20240229": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-haiku-20240307": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-2.1": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-2.0": Model(100_000, 50, 4096, ANTHROPIC),
+    "claude-instant-1.2": Model(100_000, 50, 4096, ANTHROPIC),
 }