url formatting hotfix
davidgxue committed Mar 6, 2024
1 parent df80ddc commit 899865b
Showing 4 changed files with 74 additions and 16 deletions.
42 changes: 36 additions & 6 deletions api/ask_astro/chains/answer_question.py
@@ -29,10 +29,18 @@
     MULTI_QUERY_RETRIEVER_TEMPERATURE,
 )

-with open("ask_astro/templates/combine_docs_chat_prompt.txt") as system_prompt_fd:
-    """Load system prompt template from a file and structure it."""
-    messages = [
-        SystemMessagePromptTemplate.from_template(system_prompt_fd.read()),
+with open("ask_astro/templates/combine_docs_sys_prompt_webapp.txt") as webapp_system_prompt_fd:
+    """Load system prompt template for webapp messages"""
+    webapp_messages = [
+        SystemMessagePromptTemplate.from_template(webapp_system_prompt_fd.read()),
         MessagesPlaceholder(variable_name="messages"),
         HumanMessagePromptTemplate.from_template("{question}"),
     ]
+
+with open("ask_astro/templates/combine_docs_sys_prompt_slack.txt") as slack_system_prompt_fd:
+    """Load system prompt template for slack messages"""
+    slack_messages = [
+        SystemMessagePromptTemplate.from_template(slack_system_prompt_fd.read()),
+        MessagesPlaceholder(variable_name="messages"),
+        HumanMessagePromptTemplate.from_template("{question}"),
+    ]
@@ -92,7 +100,29 @@
 )

 # Set up a ConversationalRetrievalChain to generate answers using the retriever.
-answer_question_chain = ConversationalRetrievalChain(
+webapp_answer_question_chain = ConversationalRetrievalChain(
+    retriever=llm_chain_filter_compression_retriever,
+    return_source_documents=True,
+    question_generator=LLMChain(
+        llm=AzureChatOpenAI(
+            **AzureOpenAIParams.us_east2,
+            deployment_name=CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME,
+            temperature=CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_TEMPERATURE,
+        ),
+        prompt=CONDENSE_QUESTION_PROMPT,
+    ),
+    combine_docs_chain=load_qa_chain(
+        AzureChatOpenAI(
+            **AzureOpenAIParams.us_east2,
+            deployment_name=CONVERSATIONAL_RETRIEVAL_LOAD_QA_CHAIN_DEPLOYMENT_NAME,
+            temperature=CONVERSATIONAL_RETRIEVAL_LOAD_QA_CHAIN_TEMPERATURE,
+        ),
+        chain_type="stuff",
+        prompt=ChatPromptTemplate.from_messages(webapp_messages),
+    ),
+)
+
+slack_answer_question_chain = ConversationalRetrievalChain(
     retriever=llm_chain_filter_compression_retriever,
     return_source_documents=True,
     question_generator=LLMChain(
@@ -110,6 +140,6 @@
             temperature=CONVERSATIONAL_RETRIEVAL_LOAD_QA_CHAIN_TEMPERATURE,
         ),
         chain_type="stuff",
-        prompt=ChatPromptTemplate.from_messages(messages),
+        prompt=ChatPromptTemplate.from_messages(slack_messages),
     ),
 )
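
Both new blocks load a system prompt from a file and wrap it with the same history and question slots. A minimal sketch of that shared pattern follows, using the classic LangChain prompt classes already imported in this module; the build_prompt helper is illustrative only and not part of the commit.

# Sketch of the per-client prompt pattern above; build_prompt is a
# hypothetical helper, not code from this commit.
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)

def build_prompt(template_path: str) -> ChatPromptTemplate:
    """Load a system prompt file and append history and question slots."""
    with open(template_path) as fd:
        return ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(fd.read()),
                MessagesPlaceholder(variable_name="messages"),
                HumanMessagePromptTemplate.from_template("{question}"),
            ]
        )

webapp_prompt = build_prompt("ask_astro/templates/combine_docs_sys_prompt_webapp.txt")
slack_prompt = build_prompt("ask_astro/templates/combine_docs_sys_prompt_slack.txt")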
31 changes: 21 additions & 10 deletions api/ask_astro/services/questions.py
@@ -82,7 +82,7 @@ async def answer_question(request: AskAstroRequest) -> None:
     try:
         from langchain import callbacks

-        from ask_astro.chains.answer_question import answer_question_chain
+        from ask_astro.chains.answer_question import slack_answer_question_chain, webapp_answer_question_chain

         # First, mark the request as in_progress and add it to the database
         request.status = "in_progress"
@@ -93,16 +93,27 @@

         # Run the question answering chain
         with callbacks.collect_runs() as cb:
-            result = await asyncio.to_thread(
-                lambda: answer_question_chain(
-                    {
-                        "question": request.prompt,
-                        "chat_history": [],
-                        "messages": request.messages,
-                    },
-                    metadata={"request_id": str(request.uuid)},
+            if request.client == "slack":
+                result = await asyncio.to_thread(
+                    lambda: slack_answer_question_chain(
+                        {
+                            "question": request.prompt,
+                            "chat_history": request.messages,
+                        },
+                        metadata={"request_id": str(request.uuid), "client": str(request.client)},
+                    )
+                )
+            else:
+                result = await asyncio.to_thread(
+                    lambda: webapp_answer_question_chain(
+                        {
+                            "question": request.prompt,
+                            "chat_history": [],
+                            "messages": request.messages,
+                        },
+                        metadata={"request_id": str(request.uuid), "client": str(request.client)},
+                    )
                 )
-            )
         request.langchain_run_id = cb.traced_runs[0].id

         logger.info("Question answering chain finished with result %s", result)
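
The two branches differ only in which chain runs and how chat history is wired. As a hedged sketch, the routing could be collapsed into one helper; run_chain below is hypothetical, and the request fields used are only those visible in this diff.

import asyncio

from ask_astro.chains.answer_question import slack_answer_question_chain, webapp_answer_question_chain

# Hypothetical consolidation of the if/else above, illustrative only.
async def run_chain(request):
    slack = request.client == "slack"
    chain = slack_answer_question_chain if slack else webapp_answer_question_chain
    inputs = (
        # Slack threads double as chat history; the webapp passes messages separately.
        {"question": request.prompt, "chat_history": request.messages}
        if slack
        else {"question": request.prompt, "chat_history": [], "messages": request.messages}
    )
    return await asyncio.to_thread(
        lambda: chain(inputs, metadata={"request_id": str(request.uuid), "client": str(request.client)})
    )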
17 changes: 17 additions & 0 deletions api/ask_astro/templates/combine_docs_sys_prompt_webapp.txt
@@ -0,0 +1,17 @@
+You are Ask Astro, a friendly and helpful bot.
+Only answer questions related to Astronomer, the Astro platform and Apache Airflow.
+If the question relates to pricing, licensing, or commercial usage, ask the user to contact support at www.astronomer.io/contact.
+If you don't know the answer, just say that you don't know and ask the user to contact support, don't try to make up an answer.
+If the supplied context below does not have sufficient information to help answer the question, make a note when answering to let the user know that the answer may contain false information and the user should contact support to verify.
+Be concise and precise in your answers and do not apologize.
+Format your response using Markdown syntax.
+Surround text with SINGLE * to format it in bold or provide emphasis. Examples: GOOD: *This is bold!*. BAD: **This is bold!**.
+Surround text with _ to format it in italic. Example: _This is italic._
+Use the • character for unnumbered lists.
+Use the ` character to surround inline code. Example: This is a sentence with some `inline *code*` in it.
+Use ``` to surround multi-line code blocks. Do not specify a language in code blocks. Examples: GOOD: ```This is a code block\nAnd it is multi-line``` BAD: ```python print("Hello world!")```.
+Format links using this format: [Text to display](URL). Examples: GOOD: [This message **is** a link](https://www.example.com). BAD: <https://www.example.com|This message **is** a link>.
+12 character words that start with "<@U" and end with ">" are usernames. Example: <@U024BE7LH>.
+Use the following pieces of context to answer the user's question.
+----------------
+{context}
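
The BAD link example in this template is Slack's mrkdwn form, which this hotfix steers the webapp prompt away from. Purely as an illustration, not code from this commit, a regex can rewrite stray Slack-style links into the Markdown form the prompt requests:

import re

# Illustration only: rewrite Slack-style <URL|label> links
# into the [label](URL) Markdown form the prompt asks for.
def slack_link_to_markdown(text: str) -> str:
    return re.sub(r"<(https?://[^|>]+)\|([^>]+)>", r"[\2](\1)", text)

print(slack_link_to_markdown("See <https://www.example.com|this page> for details."))
# -> See [this page](https://www.example.com) for details.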
