diff --git a/backend/src/main.py b/backend/src/main.py
index cba7305e..6b891caf 100644
--- a/backend/src/main.py
+++ b/backend/src/main.py
@@ -5,7 +5,6 @@
 import os.path
 import signal
 import sys
-import time
 from multiprocessing import Process
 
 import citation_graph
@@ -24,7 +23,7 @@
 from werkzeug.middleware.proxy_fix import ProxyFix
 
 app = Flask(__name__)
-
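+# Chat history for the RAG chatbot. NOTE: module-level state is shared by
+# every request this Flask process serves, so concurrent users would need
+# per-session storage instead.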
+history = []
 
 # ====================== Index page ======================
 _SCRIPT_DIR = os.path.dirname(__file__)
@@ -139,6 +138,33 @@ def abstract_summary():
     return Response(json.dumps(response), mimetype="application/json")
 
 
+def make_prompt(message, proteins, funct_terms, abstract):
+    """
+    Create a prompt for the chatbot.
+    Args:
+        message: Input message from user.
+        proteins: User-selected proteins to be included in the prompt.
+        funct_terms: User-selected functional terms to be included in the prompt.
+        abstract: User-selected abstracts to be included in the prompt.
+    Returns:
+        prompt: The prompt to be used for response generation.
+    """
+    pro = "use the following proteins:" if len(proteins) > 0 else ""
+    funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
+    abstract_is = (
+        "use the following abstracts and state PMID if you use them for information:"
+        if len(abstract) > 0
+        else ""
+    )
+    proteins = proteins if len(proteins) > 0 else ""
+    funct_terms = funct_terms if len(funct_terms) > 0 else ""
+    abstract = abstract if len(abstract) > 0 else ""
+    prompt = (
+        f"{message} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
+    )
+    return prompt
+
+
 @app.route("/api/subgraph/chatbot", methods=["POST"])
 def chatbot_response():
     message, background = (request.form.get("message"), request.form.get("background"))
@@ -169,24 +195,23 @@ def chatbot_response():
         embedded_query, pmids_embeddings, 6
     )
     stopwatch.round("Vector search")
-    query = ""
+    abstracts = ""
     for i in top_n_similiar:
-        query += f"PMID: {i} Abstract: {pmid_abstract[i]} "
-    answer = create_summary_RAG(
-        str(message),
+        abstracts += f"PMID: {i} Abstract: {pmid_abstract[i]} "
+    message = make_prompt(
+        message=message,
         proteins=protein_list,
         funct_terms=funct_terms_list,
-        abstract=query,
+        abstract=abstracts,
     )
-
-    def generate():
-        # Yield each message from 'answer'
-        for i in answer:
-            yield json.dumps({"message": i["response"], "pmids": top_n_similiar})
-            time.sleep(0.1)
-
+    history.append({"role": "user", "content": message})
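+    # Ask the model with the accumulated history; the assistant reply dict
+    # ({"role": "assistant", "content": ...}) is appended below so the next
+    # turn keeps conversational context.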
+    answer = create_summary_RAG(
+        history=history,
+    )
+    history.append(answer)
+    response = json.dumps({"message": answer["content"], "pmids": top_n_similiar})
     stopwatch.round("Generating answer")
-    return Response(generate(), mimetype="application/json")
+    return Response(response, mimetype="application/json")
 
 
 # ====================== Subgraph API ======================
diff --git a/backend/src/summarization/model.py b/backend/src/summarization/model.py
index daba2908..7282ff51 100644
--- a/backend/src/summarization/model.py
+++ b/backend/src/summarization/model.py
@@ -70,22 +70,6 @@ def get_response(prompt):
     return [summary]
 
 
-def create_summary_RAG(query, proteins, funct_terms, abstract):
-    pro = "use the following proteins:" if len(proteins) > 0 else ""
-    funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
-    abstract_is = (
-        "use the following abstracts and state PMID if you use them for information:"
-        if len(abstract) > 0
-        else ""
-    )
-    proteins = proteins if len(proteins) > 0 else ""
-    funct_terms = funct_terms if len(funct_terms) > 0 else ""
-    abstract = abstract if len(abstract) > 0 else ""
-
-    def get_response(prompt):
-        response = ollama.generate(model="llama3.1", prompt=prompt, stream=True)
-        return response
-
-    prompt = f"{query} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
-    summary = get_response(prompt)
-    return summary
+def create_summary_RAG(history):
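+    # Chat-style call: ollama.chat consumes the full message list and returns
+    # the assistant reply under response["message"].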
+    response = ollama.chat(model="llama3.1", messages=history)
+    return response["message"]
diff --git a/frontend/src/components/PersistentWindow.vue b/frontend/src/components/PersistentWindow.vue
index 3418f35c..14bab52a 100644
--- a/frontend/src/components/PersistentWindow.vue
+++ b/frontend/src/components/PersistentWindow.vue
@@ -28,7 +28,7 @@
           {{ element.id }}
-
+
             reference
@@ -74,7 +74,7 @@ export default {
       api: {
         chatbot: "api/subgraph/chatbot",
       },
-      controller: null,
+      sourceToken: null,
     };
   },
   computed: {
@@ -227,80 +227,38 @@
         inputDiv.innerText = "";
       }
     },
-    async streamChatbotResponse(formData) {
-      let refData = null;
-      if (this.controller) {
-        this.controller.abort();
-      }
-
-      // Create a new AbortController instance
-      this.controller = new AbortController();
-      const signal = this.controller.signal; // Get the signal
-
-      const botMessage = {
-        sender: "Bot",
-        text: "Waiting for response...", // Initially empty, will be updated progressively
-        data: [...this.tags], // Add contextual data if needed
-        ref: null, // This will hold the pmids when received
-      };
-      this.messages.push(botMessage);
-
-      const response = await fetch(this.api.chatbot, {
-        method: "POST",
-        body: formData,
-        signal: signal,
-      });
-
-      if (!response.body) {
-        throw new Error("No response body");
-      }
-
-      const reader = response.body.getReader();
-      const decoder = new TextDecoder("utf-8");
-      let done = false;
-      let fullText = "";
-
-      // Index of the newly added Bot message
-      const botMessageIndex = this.messages.length - 1;
-
-      while (!done) {
-        const { value, done: readerDone } = await reader.read();
-        done = readerDone;
-        if (done) break;
-        // Decode the streamed data
-        const chunk = decoder.decode(value || new Uint8Array(), {
-          stream: !done,
-        });
-        // Parse the chunk as JSON to extract "messages" and "pmids"
-        let parsedChunk = JSON.parse(chunk);
-        // Check if it's a message part or pmids
-        if (parsedChunk.message) {
-          // Append message chunks to the fullText
-          fullText += parsedChunk.message;
-
-          // Update the bot message progressively
-          this.updateBotMessage(fullText, botMessageIndex);
-        }
-
-        if (parsedChunk.pmids) {
-          // If pmids are received, store them for later
-          refData = parsedChunk.pmids;
-          this.messages[botMessageIndex].ref = refData;
-        }
-      }
-    },
-    updateBotMessage(text, index) {
-      // Ensure we're updating the correct (newest) bot message by index
-      this.messages[index].text = text;
-    },
     getAnswer(message) {
       let com = this;
       let formData = new FormData();
       formData.append("message", message.text);
       formData.append("background", JSON.stringify(message.data));
-      com.streamChatbotResponse(formData);
+      if (this.sourceToken) {
+        this.abort_chatbot();
+      }
+
+      this.messages.push({
+        sender: "Bot",
+        text: "Waiting for response...",
+        data: [...this.tags],
+        ref: null,
+      });
+      // POST request for the chatbot response
+      com.sourceToken = this.axios.CancelToken.source();
+      com.axios
+        .post(com.api.chatbot, formData, {
+          cancelToken: com.sourceToken.token,
+        })
+        .then((response) => {
+          const botMessageIndex = this.messages.length - 1;
+          this.messages[botMessageIndex].ref = response.data.pmids;
+          this.messages[botMessageIndex].text = response.data.message;
+          this.sourceToken = null;
+        });
+    },
+    abort_chatbot() {
+      this.sourceToken.cancel("Request canceled");
     },
     closeWindow() {
       this.windowCheck = false;
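
A note on the dropped streaming path: ollama.chat also accepts stream=True, so
the new history-based flow could stream tokens again without reverting to
ollama.generate. A minimal, untested sketch, reusing the history,
top_n_similiar, json, and Response names from main.py above:

    def generate():
        parts = []
        # stream=True yields incremental chunks; each one carries the next
        # piece of the reply under chunk["message"]["content"].
        for chunk in ollama.chat(model="llama3.1", messages=history, stream=True):
            piece = chunk["message"]["content"]
            parts.append(piece)
            yield json.dumps({"message": piece, "pmids": top_n_similiar})
        # Keep the full reply in history so follow-up turns retain context.
        history.append({"role": "assistant", "content": "".join(parts)})

    return Response(generate(), mimetype="application/json")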