NN-565 Improved prompts and code structure
Maluuck committed Oct 23, 2024
1 parent c7dc67c commit e85e2ac
Showing 3 changed files with 135 additions and 64 deletions.
91 changes: 32 additions & 59 deletions backend/src/main.py
@@ -18,7 +18,8 @@
 from dotenv import load_dotenv
 from flask import Flask, Response, request, send_from_directory
 from summarization import article_graph as summarization
-from summarization.model import create_summary_RAG, overall_summary
+from summarization.chat_bot import chat, make_prompt, populate, summarize
+from summarization.model import overall_summary
 from util.stopwatch import Stopwatch
 from werkzeug.middleware.proxy_fix import ProxyFix

@@ -138,76 +139,48 @@ def abstract_summary():
 return Response(json.dumps(response), mimetype="application/json")


-def make_prompt(message, proteins, funct_terms, abstract):
-    """
-    Create a prompt for the chatbot.
-    Args:
-        message: Input message from user.
-        proteins: User selected proteins to be included in the prompt.
-        funct_terms: User selected functional terms to be included in the prompt.
-        abstract: User selected abstracts to be included in the prompt.
-    Returns:
-        prompt: The prompt to be used for response generation.
-    """
-    pro = "use the following proteins:" if len(proteins) > 0 else ""
-    funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
-    abstract_is = (
-        "use the following abstracts and state PMID if you use them for information:"
-        if len(abstract) > 0
-        else ""
-    )
-    proteins = proteins if len(proteins) > 0 else ""
-    funct_terms = funct_terms if len(funct_terms) > 0 else ""
-    abstract = abstract if len(abstract) > 0 else ""
-    prompt = (
-        f"{message} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
-    )
-    return prompt
-
-
 @app.route("/api/subgraph/chatbot", methods=["POST"])
 def chatbot_response():
     message, background = (request.form.get("message"), request.form.get("background"))
     data = json.loads(background)
-    pmids = []
-    pmid_abstract = {}
-    protein_list = []
-    funct_terms_list = []
-    for item in data:
-        mode = item["mode"]
-        entries = [item["data"]] if item["type"] != "subset" else item["data"]
-        if mode == "citation":
-            pmids.extend([j["attributes"]["Name"] for j in entries])
-            pmid_abstract.update(
-                {j["attributes"]["Name"]: j["attributes"]["Abstract"] for j in entries}
-            )
-        elif mode == "protein":
-            protein_list.extend([j["attributes"]["Name"] for j in entries])
-        else:
-            funct_terms_list.extend([j["label"] for j in entries])
     stopwatch = Stopwatch()
     driver = database.get_driver()
-    pmids_embeddings = queries.fetch_vector_embeddings(driver=driver, pmids=pmids)
-    stopwatch.round("Fetching embeddings")
-    embedded_query = summarization.generate_embedding(str(message))
-    stopwatch.round("Embedding query")
-    top_n_similiar = summarization.top_n_similar_vectors(
-        embedded_query, pmids_embeddings, 6
-    )
-    stopwatch.round("Vector search")
-    abstracts = ""
-    for i in top_n_similiar:
-        abstracts += f"PMID: {i} Abstract: {pmid_abstract[i]} "
+    # Bring background data into usable format
+    pmids, pmid_abstract, protein_list, funct_terms_list = populate(data)
+    abstracts = []
+    top_n_similiar = []
+    # Case abstracts are selected
+    if len(pmids) > 0:
+        pmids_embeddings = queries.fetch_vector_embeddings(driver=driver, pmids=pmids)
+        stopwatch.round("Fetching embeddings")
+        embedded_query = summarization.generate_embedding(str(message))
+        stopwatch.round("Embedding query")
+        top_n_similiar = summarization.top_n_similar_vectors(
+            embedded_query, pmids_embeddings, 6
+        )
+        unsummarized = (
+            [
+                [pmid_abstract[i] for i in top_n_similiar[:3]],
+                [pmid_abstract[i] for i in top_n_similiar[3:]],
+            ]
+            if len(top_n_similiar) > 3
+            else [pmid_abstract[i] for i in top_n_similiar]
+        )
+        summarized = summarize(unsummarized, protein_list)
+        stopwatch.round("Vector search")
+        abstracts = [
+            f"Abstract {num+1} with PMID {i}: {summarized[num]}"
+            for num, i in enumerate(top_n_similiar)
+        ]
+        protein_list = []
     message = make_prompt(
         message=message,
-        proteins=protein_list,
         funct_terms=funct_terms_list,
+        proteins=protein_list,
         abstract=abstracts,
     )
     history.append({"role": "user", "content": message})
-    answer = create_summary_RAG(
-        history=history,
-    )
+    answer = chat(history=history)
     history.append(answer)
     response = json.dumps({"message": answer["content"], "pmids": top_n_similiar})
     stopwatch.round("Generating answer")
103 changes: 103 additions & 0 deletions backend/src/summarization/chat_bot.py
@@ -0,0 +1,103 @@
import re
from ast import literal_eval

import ollama


def make_prompt(message, proteins, funct_terms, abstract):
    """
    Create a prompt for the chatbot.
    Args:
        message: Input message from user.
        proteins: User selected proteins to be included in the prompt.
        funct_terms: User selected functional terms to be included in the prompt.
        abstract: User selected abstracts to be included in the prompt.
    Returns:
        final_prompt: The prompt to be used for response generation.
    """
    functional_term_background = (
        f"Functional terms: {funct_terms} \n" if len(funct_terms) > 0 else ""
    )
    protein_background = f"Proteins: {proteins} \n" if len(proteins) > 0 else ""
    abstracts = f"Scientific Abstracts: {abstract} \n" if len(abstract) > 0 else ""
    functional_term_prompt = (
        "with the background of the provided functional terms, "
        if len(funct_terms) > 0
        else ""
    )
    protein_prompt = (
        "with the background of the provided proteins, " if len(proteins) > 0 else ""
    )
    abstract_prompt = (
        f"use the information from the {len(abstract)} provided abstracts and state the pmids if used."
        if len(abstract) > 0
        else ""
    )

    final_prompt = (
        f"{protein_background}{functional_term_background}{abstracts}"
        f"{message} {protein_prompt}{functional_term_prompt}{abstract_prompt}"
    )
    return final_prompt
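A quick illustration of the assembled prompt (the inputs below are made-up examples, not values from this commit):

prompt = make_prompt(
    message="What is known about these proteins?",
    proteins=["TP53", "BRCA1"],
    funct_terms=["apoptosis"],
    abstract=["Abstract 1 with PMID 12345: ..."],
)
# Yields roughly:
# Proteins: ['TP53', 'BRCA1']
# Functional terms: ['apoptosis']
# Scientific Abstracts: ['Abstract 1 with PMID 12345: ...']
# What is known about these proteins? with the background of the provided proteins,
# with the background of the provided functional terms, use the information from
# the 1 provided abstracts and state the pmids if used.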


def populate(data):
    """
    Bring the selected background data into a usable format.
    Args:
        data: Parsed background selection from the frontend; a list of items,
              each with a "type" ("subset", "protein" or a functional-term
              type) and the associated "data".
    Returns:
        pmids, pmid_abstract, protein_list, funct_terms_list: PMIDs, a
        PMID-to-abstract mapping, protein names and functional-term names.
    """
    pmids = []
    pmid_abstract = {}
    protein_list = []
    funct_terms_list = []
    for item in data:
        data_type = item["type"]
        entries = [item["data"]] if item["type"] != "subset" else item["data"]
        if data_type == "subset":
            pmids.extend([j["attributes"]["Name"] for j in entries])
            pmid_abstract.update(
                {
                    # Strip apostrophes so the abstracts survive later parsing
                    j["attributes"]["Name"]: j["attributes"]["Abstract"].replace(
                        "'", ""
                    )
                    for j in entries
                }
            )
        elif data_type == "protein":
            protein_list.extend([j["attributes"]["Name"] for j in entries])
        else:
            funct_terms_list.extend([j["name"] for j in entries])
    return pmids, pmid_abstract, protein_list, funct_terms_list
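For reference, a sketch of the background payload shape populate expects, inferred from the field accesses above (the concrete values are illustrative):

data = [
    {
        "type": "subset",
        "data": [{"attributes": {"Name": "12345", "Abstract": "Example abstract"}}],
    },
    {"type": "protein", "data": {"attributes": {"Name": "TP53"}}},
    {"type": "term", "data": {"name": "apoptosis"}},  # any non-subset/protein type
]
pmids, pmid_abstract, proteins, terms = populate(data)
# pmids == ["12345"], proteins == ["TP53"], terms == ["apoptosis"]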


def chat(history, model="llama3.1"):
    """
    Generate a reply from the AI model (the chat history is taken into consideration).
    Args:
        history: Chat history needed for AI memory; each entry has the format
                 {"role": <user, assistant or system>, "content": <message>}
        model: AI model to be used, defaults to llama3.1
    Returns:
        response["message"]: reply of the model
    """
    response = ollama.chat(model=model, messages=history, options={"temperature": 0.0})
    return response["message"]
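A minimal round trip, assuming a local Ollama server with the llama3.1 model pulled:

history = [{"role": "user", "content": "Name one tumor suppressor protein."}]
reply = chat(history=history)  # {"role": "assistant", "content": "..."}
history.append(reply)  # keep the reply so the next turn has context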


def summarize(input_text, proteins):
    """
    Summarize abstracts obtained by Graph_RAG.
    Args:
        input_text: Inputs to be summarized; the format is a list of lists
                    (batches of abstracts).
        proteins: Proteins to be focused on when generating the summary.
    Returns:
        flattened_response: List of the summarized abstracts.
    """
    raw_response = [
        ollama.generate(
            "llama3.1",
            f"{i} summarize with a focus on {proteins} each one of the {len(i)} abstracts in 30 words into a list i.e format ['summary 1', .. , 'summary n'] dont say anything like here are the summaries or so, make sure it has the correct format for python",
        )["response"]
        for i in input_text
    ]
    # Strip stray apostrophes inside the summaries, then parse the model's
    # list literal with literal_eval.
    cleaned_response = [
        literal_eval(re.sub(r"(?<![\[\],\s])'(?![\[\],])", "", i.replace("\n", "")))
        for i in raw_response
    ]
    # Flatten the per-batch lists into a single list of summaries.
    flattened_response = [i for j in cleaned_response for i in j]
    return flattened_response
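summarize depends on the model emitting a valid Python list literal; the regex removes apostrophes inside the summaries (but not the quotes delimiting them) so literal_eval can parse the string. An illustrative round trip with a made-up model response:

raw = "['TP53 activates apoptosis.', 'BRCA1 repairs DNA.']"
cleaned = literal_eval(
    re.sub(r"(?<![\[\],\s])'(?![\[\],])", "", raw.replace("\n", ""))
)
# cleaned == ['TP53 activates apoptosis.', 'BRCA1 repairs DNA.']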
5 changes: 0 additions & 5 deletions backend/src/summarization/model.py
@@ -68,8 +68,3 @@ def get_response(prompt):
     summary = get_response(prompt)

     return [summary]
-
-
-def create_summary_RAG(history):
-    response = ollama.chat(model="llama3.1", messages=history)
-    return response["message"]
