diff --git a/backend/src/main.py b/backend/src/main.py
index 6b891caf..9d5f507c 100644
--- a/backend/src/main.py
+++ b/backend/src/main.py
@@ -18,7 +18,8 @@ from dotenv import load_dotenv
 from flask import Flask, Response, request, send_from_directory
 
 from summarization import article_graph as summarization
-from summarization.model import create_summary_RAG, overall_summary
+from summarization.chat_bot import chat, make_prompt, populate, summarize
+from summarization.model import overall_summary
 from util.stopwatch import Stopwatch
 from werkzeug.middleware.proxy_fix import ProxyFix
 
@@ -138,76 +139,48 @@ def abstract_summary():
     return Response(json.dumps(response), mimetype="application/json")
 
 
-def make_prompt(message, proteins, funct_terms, abstract):
-    """
-    Create a prompt for the chatbot.
-    Args:
-        message: Input message from user.
-        proteins: User selected proteins to be included in the prompt.
-        funct_terms: User selected functional terms to be included in the prompt.
-        abstract: User selected abstracts to be included in the prompt.
-    Returns:
-        prompt: The prompt to be used for response generation.
-    """
-    pro = "use the following proteins:" if len(proteins) > 0 else ""
-    funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
-    abstract_is = (
-        "use the following abstracts and state PMID if you use them for information:"
-        if len(abstract) > 0
-        else ""
-    )
-    proteins = proteins if len(proteins) > 0 else ""
-    funct_terms = funct_terms if len(funct_terms) > 0 else ""
-    abstract = abstract if len(abstract) > 0 else ""
-    prompt = (
-        f"{message} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
-    )
-    return prompt
-
-
 @app.route("/api/subgraph/chatbot", methods=["POST"])
 def chatbot_response():
     message, background = (request.form.get("message"), request.form.get("background"))
     data = json.loads(background)
-    pmids = []
-    pmid_abstract = {}
-    protein_list = []
-    funct_terms_list = []
-    for item in data:
-        mode = item["mode"]
-        entries = [item["data"]] if item["type"] != "subset" else item["data"]
-        if mode == "citation":
-            pmids.extend([j["attributes"]["Name"] for j in entries])
-            pmid_abstract.update(
-                {j["attributes"]["Name"]: j["attributes"]["Abstract"] for j in entries}
-            )
-        elif mode == "protein":
-            protein_list.extend([j["attributes"]["Name"] for j in entries])
-        else:
-            funct_terms_list.extend([j["label"] for j in entries])
     stopwatch = Stopwatch()
     driver = database.get_driver()
-    pmids_embeddings = queries.fetch_vector_embeddings(driver=driver, pmids=pmids)
-    stopwatch.round("Fetching embeddings")
-    embedded_query = summarization.generate_embedding(str(message))
-    stopwatch.round("Embedding query")
-    top_n_similiar = summarization.top_n_similar_vectors(
-        embedded_query, pmids_embeddings, 6
-    )
-    stopwatch.round("Vector search")
-    abstracts = ""
-    for i in top_n_similiar:
-        abstracts += f"PMID: {i} Abstract: {pmid_abstract[i]} "
+    # Bring background data into usable format
+    pmids, pmid_abstract, protein_list, funct_terms_list = populate(data)
+    abstracts = []
+    top_n_similiar = []
+    # Case: abstracts are selected
+    if len(pmids) > 0:
+        pmids_embeddings = queries.fetch_vector_embeddings(driver=driver, pmids=pmids)
+        stopwatch.round("Fetching embeddings")
+        embedded_query = summarization.generate_embedding(str(message))
+        stopwatch.round("Embedding query")
+        top_n_similiar = summarization.top_n_similar_vectors(
+            embedded_query, pmids_embeddings, 6
+        )
+        unsummarized = (
+            [
+                [pmid_abstract[i] for i in top_n_similiar[:3]],
+                [pmid_abstract[i] for i in top_n_similiar[3:]],
+            ]
+            if len(top_n_similiar) > 3
+            else [[pmid_abstract[i] for i in top_n_similiar]]
+        )
+        summarized = summarize(unsummarized, protein_list)
+        stopwatch.round("Vector search and summarization")
+        abstracts = [
+            f"Abstract {num+1} with PMID {i}: {summarized[num]}"
+            for num, i in enumerate(top_n_similiar)
+        ]
+        # Proteins are already folded into the abstract summaries above
+        protein_list = []
     message = make_prompt(
         message=message,
-        proteins=protein_list,
         funct_terms=funct_terms_list,
+        proteins=protein_list,
         abstract=abstracts,
     )
     history.append({"role": "user", "content": message})
-    answer = create_summary_RAG(
-        history=history,
-    )
+    answer = chat(history=history)
     history.append(answer)
     response = json.dumps({"message": answer["content"], "pmids": top_n_similiar})
     stopwatch.round("Generating answer")
diff --git a/backend/src/summarization/chat_bot.py b/backend/src/summarization/chat_bot.py
new file mode 100644
index 00000000..4fd6ad8f
--- /dev/null
+++ b/backend/src/summarization/chat_bot.py
@@ -0,0 +1,103 @@
+import re
+from ast import literal_eval
+
+import ollama
+
+
+def make_prompt(message, proteins, funct_terms, abstract):
+    """
+    Create a prompt for the chatbot.
+    Args:
+        message: Input message from user.
+        proteins: User selected proteins to be included in the prompt.
+        funct_terms: User selected functional terms to be included in the prompt.
+        abstract: User selected abstracts to be included in the prompt.
+    Returns:
+        prompt: The prompt to be used for response generation.
+    """
+    functional_term_background = (
+        f"Functional terms: {funct_terms} \n" if len(funct_terms) > 0 else ""
+    )
+    protein_background = f"Proteins: {proteins} \n" if len(proteins) > 0 else ""
+    abstracts = f"Scientific Abstracts: {abstract} \n" if len(abstract) > 0 else ""
+    functional_term_prompt = (
+        "with the background of the provided functional terms, "
+        if len(funct_terms) > 0
+        else ""
+    )
+    protein_prompt = (
+        "with the background of the provided proteins, " if len(proteins) > 0 else ""
+    )
+    abstract_prompt = (
+        f"use the information from the {len(abstract)} provided abstracts and state the pmids if used."
+        if len(abstract) > 0
+        else ""
+    )
+
+    final_prompt = f"{protein_background}{functional_term_background}{abstracts}{message} {protein_prompt}{functional_term_prompt}{abstract_prompt}"
+    return final_prompt
+
+
+def populate(data):
+    """
+    Bring the selected background data into a usable format.
+
+    Args:
+        data: Deserialized background selection; each item carries a "type" and "data" field.
+
+    Returns:
+        pmids, pmid_abstract, protein_list and funct_terms_list extracted from data.
+    """
+    pmids = []
+    pmid_abstract = {}
+    protein_list = []
+    funct_terms_list = []
+    for item in data:
+        data_type = item["type"]
+        entries = [item["data"]] if data_type != "subset" else item["data"]
+        if data_type == "subset":
+            pmids.extend([j["attributes"]["Name"] for j in entries])
+            # Strip single quotes so the abstracts survive the literal_eval round trip
+            pmid_abstract.update(
+                {
+                    j["attributes"]["Name"]: j["attributes"]["Abstract"].replace(
+                        "'", ""
+                    )
+                    for j in entries
+                }
+            )
+        elif data_type == "protein":
+            protein_list.extend([j["attributes"]["Name"] for j in entries])
+        else:
+            funct_terms_list.extend([j["name"] for j in entries])
+    return pmids, pmid_abstract, protein_list, funct_terms_list
+
+
+def chat(history, model="llama3.1"):
+    """
+    Generate a reply from the AI model (chat history taken into consideration).
+
+    Args:
+        model: AI model to be used, defaults to llama3.1
+        history: Chat history needed for AI memory, has format: {"role": , "content": }
+
+    Returns:
+        response["message"]: reply of the model
+    """
+    response = ollama.chat(model=model, messages=history, options={"temperature": 0.0})
+    return response["message"]
+
+
+def summarize(input_text, proteins):
+    """
+    Summarize abstracts obtained by Graph_RAG.
+
+    Args:
+        input_text: inputs to be summarized, format is list of lists
+        proteins: proteins to be focused on when generating the summary
+
+    Returns:
+        flattened_response: List of the summarized abstracts
+    """
+    raw_response = [
+        ollama.generate(
+            "llama3.1",
+            f"{i} summarize with a focus on {proteins} each one of the {len(i)} abstracts in 30 words into a list i.e format ['summary 1', .. , 'summary n'] dont say anything like here are the summaries or so, make sure it has the correct format for python",
+        )["response"]
+        for i in input_text
+    ]
+    cleaned_response = [
+        literal_eval(re.sub(r"(?
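
For reviewers: a minimal usage sketch of the new chat_bot helpers (not part of the diff). It assumes a reachable local Ollama server with the llama3.1 model pulled, and that summarize() returns one flat list of summary strings parallel to its input abstracts, as its docstring states; the protein names, functional term, and abstract texts below are made up.

    # Hypothetical review sketch -- not part of this PR.
    from summarization.chat_bot import chat, make_prompt, summarize

    history = []

    # One batch of (made-up) abstracts, summarized with a focus on two proteins.
    summaries = summarize(
        [["First example abstract text ...", "Second example abstract text ..."]],
        ["CCR5", "CXCR4"],
    )

    # proteins=[] mirrors the endpoint: once the proteins are folded into the
    # summaries, they are not repeated in the prompt background.
    prompt = make_prompt(
        message="How do these receptors interact?",
        proteins=[],
        funct_terms=["chemokine signaling"],
        abstract=[f"Abstract {n + 1}: {s}" for n, s in enumerate(summaries)],
    )
    history.append({"role": "user", "content": prompt})

    answer = chat(history=history)  # {"role": "assistant", "content": "..."}
    history.append(answer)
    print(answer["content"])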