update models
mruwnik committed Dec 8, 2024
1 parent 86f6751 commit 14af5c8
Showing 3 changed files with 51 additions and 17 deletions.
45 changes: 38 additions & 7 deletions api/src/stampy_chat/chat.py
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Callable, Dict, List

 from langchain.chains import LLMChain, OpenAIModerationChain
@@ -14,13 +15,19 @@
 from langchain.schema import AIMessage, BaseMessage, HumanMessage, PromptValue, SystemMessage

 from stampy_chat.env import OPENAI_API_KEY, ANTHROPIC_API_KEY, LANGCHAIN_API_KEY, LANGCHAIN_TRACING_V2, SUMMARY_MODEL
-from stampy_chat.settings import Settings, MODELS, OPENAI, ANTRHROPIC
+from stampy_chat.settings import Settings, MODELS, OPENAI, ANTHROPIC
 from stampy_chat.callbacks import StampyCallbackHandler, BroadcastCallbackHandler, LoggerCallbackHandler
 from stampy_chat.followups import StampyChain
 from stampy_chat.citations import make_example_selector

 from langsmith import Client

+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 if LANGCHAIN_TRACING_V2 == "true":
     if not LANGCHAIN_API_KEY:
         raise Exception("Langsmith tracing is enabled but no api key was provided. Please set LANGCHAIN_API_KEY in the .env file.")
@@ -192,7 +199,7 @@ def get_model(**kwargs):
     model = MODELS.get(kwargs.get('model'))
     if not model:
         raise ValueError("No model provided")
-    if model.publisher == ANTRHROPIC:
+    if model.publisher == ANTHROPIC:
         return ChatAnthropicWrapper(anthropic_api_key=ANTHROPIC_API_KEY, **kwargs)
     if model.publisher == OPENAI:
         return ChatOpenAI(openai_api_key=OPENAI_API_KEY, **kwargs)
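
For context, a minimal sketch of how this dispatch is exercised (the keyword arguments are illustrative; get_model simply forwards them to the underlying client):

    # Hypothetical call: 'claude-3-5-sonnet-latest' has publisher ANTHROPIC,
    # so this returns a ChatAnthropicWrapper; 'gpt-4o' would return ChatOpenAI.
    chat_model = get_model(model='claude-3-5-sonnet-latest', streaming=True)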
@@ -245,7 +252,7 @@ def make_history_summary(settings):
 def make_prompt(settings, chat_model, callbacks):
     """Create a proper prompt object with all the necessary steps."""
     # 1. Create the context prompt from items fetched from pinecone
-    context_template = "[{{reference}}] {{title}} {{authors | join(', ')}} - {{date_published}} {{text}}"
+    context_template = "\n\n[{{reference}}] {{title}} {{authors | join(', ')}} - {{date_published}} {{text}}\n\n"
     context_prompt = MessageBufferPromptTemplate(
         example_selector=make_example_selector(k=settings.topKBlocks, callbacks=callbacks),
         example_prompt=ChatPromptTemplate.from_template(context_template, template_format="jinja2"),
@@ -261,7 +268,10 @@ def make_prompt(settings, chat_model, callbacks):
     query_prompt = ChatPromptTemplate.from_messages(
         [
             HumanMessage(content=settings.question_prompt),
-            HumanMessagePromptTemplate.from_template(template='Q: {history_summary}: {query}', role='user'),
+            HumanMessagePromptTemplate.from_template(
+                template='{history_summary}{delimiter}{query}',
+                partial_variables={"delimiter": lambda **kwargs: ": " if kwargs.get("history_summary") else ""}
+            ),
         ]
     )
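
A quick sketch of what the new conditional delimiter yields at render time (plain str.format standing in for the template substitution; values are illustrative):

    template = '{history_summary}{delimiter}{query}'
    # With a summary, delimiter is ": ", joining summary and query into one line:
    template.format(history_summary='User asked about mesa-optimizers',
                    delimiter=': ', query='what are the risks?')
    # -> 'User asked about mesa-optimizers: what are the risks?'
    # With no summary, both parts are empty, leaving only the raw query:
    template.format(history_summary='', delimiter='', query='what are the risks?')
    # -> 'what are the risks?'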

@@ -334,16 +344,37 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settings
         model=settings.completions
     )

-    chain = make_history_summary(settings) | LLMChain(
+    history_summary_chain = make_history_summary(settings)
+
+    if history:
+        history_summary_result = history_summary_chain.invoke({"query": query, 'history': history})
+        history_summary = history_summary_result.get('history_summary', '')
+    else:
+        history_summary = ''
+
+    delimiter = ": " if history_summary else ""
+
+    llm_chain = LLMChain(
         llm=chat_model,
         verbose=False,
         prompt=make_prompt(settings, chat_model, callbacks),
         memory=make_memory(settings, history, callbacks)
     )
+
+    chain = history_summary_chain | llm_chain
     if followups:
         chain = chain | StampyChain(callbacks=callbacks)
-    result = chain.invoke({"query": query, 'history': history}, {'callbacks': []})
+
+    chain_input = {
+        "query": query,
+        'history': history,
+        'history_summary': history_summary,
+        'delimiter': delimiter,
+    }
+
+    result = chain.invoke(chain_input)

     if callback:
         callback({'state': 'done'})
-        callback(None)  # make sure the callback handler know that things have ended
+        callback(None)
     return result
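
Taken together, a minimal sketch of calling the updated run_query (the history entry shape and the Settings defaults are assumptions based on the signature above):

    result = run_query(
        session_id='example-session',  # hypothetical id
        query='What is instrumental convergence?',
        history=[  # assumed shape: role/content dicts
            {'role': 'user', 'content': 'What is AI alignment?'},
            {'role': 'assistant', 'content': 'Getting AI systems to do what we intend.'},
        ],
        settings=Settings(),  # assumed defaults; completions model comes from COMPLETIONS_MODEL
    )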
4 changes: 2 additions & 2 deletions api/src/stampy_chat/env.py
@@ -20,8 +20,8 @@

 ### Models ###
 EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-ada-002")
-SUMMARY_MODEL = os.environ.get("SUMMARY_MODEL", "claude-3-5-sonnet-20240620")
-COMPLETIONS_MODEL = os.environ.get("COMPLETIONS_MODEL", "claude-3-5-sonnet-20240620")
+SUMMARY_MODEL = os.environ.get("SUMMARY_MODEL", "claude-3-5-sonnet-latest")
+COMPLETIONS_MODEL = os.environ.get("COMPLETIONS_MODEL", "claude-3-5-sonnet-latest")

 ### Pinecone ###
 PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
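
Swapping models therefore needs no code change; a sketch of overriding the new defaults via the environment (.env entries, assuming the same .env file the tracing check in chat.py refers to; both values exist in MODELS below):

    SUMMARY_MODEL=claude-3-5-sonnet-20241022
    COMPLETIONS_MODEL=gpt-4o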
19 changes: 11 additions & 8 deletions api/src/stampy_chat/settings.py
@@ -59,20 +59,23 @@
     'modes': PROMPT_MODES,
 }
 OPENAI = 'openai'
-ANTRHROPIC = 'anthropic'
+ANTHROPIC = 'anthropic'
 MODELS = {
     'gpt-3.5-turbo': Model(4097, 10, 4096, OPENAI),
     'gpt-3.5-turbo-16k': Model(16385, 30, 4096, OPENAI),
     'gpt-4': Model(8192, 20, 4096, OPENAI),
     "gpt-4-turbo-preview": Model(128000, 50, 4096, OPENAI),
     "gpt-4o": Model(128000, 50, 4096, OPENAI),
-    "claude-3-opus-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-3-5-sonnet-20240620": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-3-sonnet-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-3-haiku-20240307": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-2.1": Model(200_000, 50, 4096, ANTRHROPIC),
-    "claude-2.0": Model(100_000, 50, 4096, ANTRHROPIC),
-    "claude-instant-1.2": Model(100_000, 50, 4096, ANTRHROPIC),
+    "gpt-4o-mini": Model(128000, 50, 4096, OPENAI),
+    "claude-3-opus-20240229": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-5-sonnet-20240620": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-5-sonnet-20241022": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-5-sonnet-latest": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-sonnet-20240229": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-3-haiku-20240307": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-2.1": Model(200_000, 50, 4096, ANTHROPIC),
+    "claude-2.0": Model(100_000, 50, 4096, ANTHROPIC),
+    "claude-instant-1.2": Model(100_000, 50, 4096, ANTHROPIC),
 }
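
For reference, each entry maps a model name to a Model record whose positional fields appear to be context window, top-k retrieved blocks, max completion tokens, and publisher (the field names here are assumptions); a lookup sketch:

    model = MODELS['claude-3-5-sonnet-latest']
    # model.publisher == ANTHROPIC, which get_model() in chat.py checks
    # to pick ChatAnthropicWrapper over ChatOpenAI.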


