diff --git a/backend/src/main.py b/backend/src/main.py index cba7305e..6b891caf 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -5,7 +5,6 @@ import os.path import signal import sys -import time from multiprocessing import Process import citation_graph @@ -24,7 +23,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix app = Flask(__name__) - +history = [] # ====================== Index page ====================== _SCRIPT_DIR = os.path.dirname(__file__) @@ -139,6 +138,33 @@ def abstract_summary(): return Response(json.dumps(response), mimetype="application/json") +def make_prompt(message, proteins, funct_terms, abstract): + """ + Create a prompt for the chatbot. + Args: + message: Input message from user. + proteins: User selected proteins to be included in the prompt. + funct_terms: User selected functional terms to be included in the prompt. + abstract: User selected abstracts to be included in the prompt. + Returns: + prompt: The prompt to be used for response generation. + """ + pro = "use the following proteins:" if len(proteins) > 0 else "" + funct = "use the following functional terms:" if len(funct_terms) > 0 else "" + abstract_is = ( + "use the following abstracts and state PMID if you use them for information:" + if len(abstract) > 0 + else "" + ) + proteins = proteins if len(proteins) > 0 else "" + funct_terms = funct_terms if len(funct_terms) > 0 else "" + abstract = abstract if len(abstract) > 0 else "" + prompt = ( + f"{message} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}" + ) + return prompt + + @app.route("/api/subgraph/chatbot", methods=["POST"]) def chatbot_response(): message, background = (request.form.get("message"), request.form.get("background")) @@ -169,24 +195,23 @@ def chatbot_response(): embedded_query, pmids_embeddings, 6 ) stopwatch.round("Vector search") - query = "" + abstracts = "" for i in top_n_similiar: - query += f"PMID: {i} Abstract: {pmid_abstract[i]} " - answer = create_summary_RAG( - str(message), + abstracts += f"PMID: {i} Abstract: {pmid_abstract[i]} " + message = make_prompt( + message=message, proteins=protein_list, funct_terms=funct_terms_list, - abstract=query, + abstract=abstracts, ) - - def generate(): - # Yield each message from 'answer' - for i in answer: - yield json.dumps({"message": i["response"], "pmids": top_n_similiar}) - time.sleep(0.1) - + history.append({"role": "user", "content": message}) + answer = create_summary_RAG( + history=history, + ) + history.append(answer) + response = json.dumps({"message": answer["content"], "pmids": top_n_similiar}) stopwatch.round("Generating answer") - return Response(generate(), mimetype="application/json") + return Response(response, mimetype="application/json") # ====================== Subgraph API ====================== diff --git a/backend/src/summarization/model.py b/backend/src/summarization/model.py index daba2908..7282ff51 100644 --- a/backend/src/summarization/model.py +++ b/backend/src/summarization/model.py @@ -70,22 +70,6 @@ def get_response(prompt): return [summary] -def create_summary_RAG(query, proteins, funct_terms, abstract): - pro = "use the following proteins:" if len(proteins) > 0 else "" - funct = "use the following functional terms:" if len(funct_terms) > 0 else "" - abstract_is = ( - "use the following abstracts and state PMID if you use them for information:" - if len(abstract) > 0 - else "" - ) - proteins = proteins if len(proteins) > 0 else "" - funct_terms = funct_terms if len(funct_terms) > 0 else "" - abstract = abstract if len(abstract) > 0 else "" - - def get_response(prompt): - response = ollama.generate(model="llama3.1", prompt=prompt, stream=True) - return response - - prompt = f"{query} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}" - summary = get_response(prompt) - return summary +def create_summary_RAG(history): + response = ollama.chat(model="llama3.1", messages=history) + return response["message"] diff --git a/frontend/src/components/PersistentWindow.vue b/frontend/src/components/PersistentWindow.vue index 3418f35c..14bab52a 100644 --- a/frontend/src/components/PersistentWindow.vue +++ b/frontend/src/components/PersistentWindow.vue @@ -28,7 +28,7 @@ {{ element.id }} -