Skip to content

Commit

Permalink
Merge pull request #75 from BackofenLab/workMalek
Browse files Browse the repository at this point in the history
Chatbot history and removing of chat streaming
  • Loading branch information
Maluuck authored Oct 21, 2024
2 parents d8b48e6 + c7dc67c commit 51e6c41
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 103 deletions.
55 changes: 40 additions & 15 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os.path
import signal
import sys
import time
from multiprocessing import Process

import citation_graph
Expand All @@ -24,7 +23,7 @@
from werkzeug.middleware.proxy_fix import ProxyFix

app = Flask(__name__)

history = []
# ====================== Index page ======================

_SCRIPT_DIR = os.path.dirname(__file__)
Expand Down Expand Up @@ -139,6 +138,33 @@ def abstract_summary():
return Response(json.dumps(response), mimetype="application/json")


def make_prompt(message, proteins, funct_terms, abstract):
"""
Create a prompt for the chatbot.
Args:
message: Input message from user.
proteins: User selected proteins to be included in the prompt.
funct_terms: User selected functional terms to be included in the prompt.
abstract: User selected abstracts to be included in the prompt.
Returns:
prompt: The prompt to be used for response generation.
"""
pro = "use the following proteins:" if len(proteins) > 0 else ""
funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
abstract_is = (
"use the following abstracts and state PMID if you use them for information:"
if len(abstract) > 0
else ""
)
proteins = proteins if len(proteins) > 0 else ""
funct_terms = funct_terms if len(funct_terms) > 0 else ""
abstract = abstract if len(abstract) > 0 else ""
prompt = (
f"{message} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
)
return prompt


@app.route("/api/subgraph/chatbot", methods=["POST"])
def chatbot_response():
message, background = (request.form.get("message"), request.form.get("background"))
Expand Down Expand Up @@ -169,24 +195,23 @@ def chatbot_response():
embedded_query, pmids_embeddings, 6
)
stopwatch.round("Vector search")
query = ""
abstracts = ""
for i in top_n_similiar:
query += f"PMID: {i} Abstract: {pmid_abstract[i]} "
answer = create_summary_RAG(
str(message),
abstracts += f"PMID: {i} Abstract: {pmid_abstract[i]} "
message = make_prompt(
message=message,
proteins=protein_list,
funct_terms=funct_terms_list,
abstract=query,
abstract=abstracts,
)

def generate():
# Yield each message from 'answer'
for i in answer:
yield json.dumps({"message": i["response"], "pmids": top_n_similiar})
time.sleep(0.1)

history.append({"role": "user", "content": message})
answer = create_summary_RAG(
history=history,
)
history.append(answer)
response = json.dumps({"message": answer["content"], "pmids": top_n_similiar})
stopwatch.round("Generating answer")
return Response(generate(), mimetype="application/json")
return Response(response, mimetype="application/json")


# ====================== Subgraph API ======================
Expand Down
22 changes: 3 additions & 19 deletions backend/src/summarization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,6 @@ def get_response(prompt):
return [summary]


def create_summary_RAG(query, proteins, funct_terms, abstract):
pro = "use the following proteins:" if len(proteins) > 0 else ""
funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
abstract_is = (
"use the following abstracts and state PMID if you use them for information:"
if len(abstract) > 0
else ""
)
proteins = proteins if len(proteins) > 0 else ""
funct_terms = funct_terms if len(funct_terms) > 0 else ""
abstract = abstract if len(abstract) > 0 else ""

def get_response(prompt):
response = ollama.generate(model="llama3.1", prompt=prompt, stream=True)
return response

prompt = f"{query} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
summary = get_response(prompt)
return summary
def create_summary_RAG(history):
response = ollama.chat(model="llama3.1", messages=history)
return response["message"]
96 changes: 27 additions & 69 deletions frontend/src/components/PersistentWindow.vue
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
{{ element.id }}
</span>
</div>
<div v-if="msg.ref">
<div v-if="msg.ref !== null">
<span class="small-tag blue" @click="searchRef(msg.ref)">
reference
</span>
Expand Down Expand Up @@ -74,7 +74,7 @@ export default {
api: {
chatbot: "api/subgraph/chatbot",
},
controller: null,
sourceToken: null,
};
},
computed: {
Expand Down Expand Up @@ -227,80 +227,38 @@ export default {
inputDiv.innerText = "";
}
},
async streamChatbotResponse(formData) {
let refData = null;
if (this.controller) {
this.controller.abort();
}
// Create a new AbortController instance
this.controller = new AbortController();
const signal = this.controller.signal; // Get the signal
const botMessage = {
sender: "Bot",
text: "Waiting for response...", // Initially empty, will be updated progressively
data: [...this.tags], // Add contextual data if needed
ref: null, // This will hold the pmids when received
};
this.messages.push(botMessage);
const response = await fetch(this.api.chatbot, {
method: "POST",
body: formData,
signal: signal,
});
if (!response.body) {
throw new Error("No response body");
}
const reader = response.body.getReader();
const decoder = new TextDecoder("utf-8");
let done = false;
let fullText = "";
// Index of the newly added Bot message
const botMessageIndex = this.messages.length - 1;
while (!done) {
const { value, done: readerDone } = await reader.read();
done = readerDone;
if (done) break;
// Decode the streamed data
const chunk = decoder.decode(value || new Uint8Array(), {
stream: !done,
});
// Parse the chunk as JSON to extract "messages" and "pmids"
let parsedChunk = JSON.parse(chunk);
// Check if it's a message part or pmids
if (parsedChunk.message) {
// Append message chunks to the fullText
fullText += parsedChunk.message;
// Update the bot message progressively
this.updateBotMessage(fullText, botMessageIndex);
}
if (parsedChunk.pmids) {
// If pmids are received, store them for later
refData = parsedChunk.pmids;
this.messages[botMessageIndex].ref = refData;
}
}
},
updateBotMessage(text, index) {
// Ensure we're updating the correct (newest) bot message by index
this.messages[index].text = text;
},
getAnswer(message) {
let com = this;
let formData = new FormData();
formData.append("message", message.text);
formData.append("background", JSON.stringify(message.data));
com.streamChatbotResponse(formData);
if (this.sourceToken) {
this.abort_chatbot();
}
this.messages.push({
sender: "Bot",
text: "Waiting for response...",
data: [...this.tags],
ref: null,
});
//POST request for generating pathways
com.sourceToken = this.axios.CancelToken.source();
com.axios
.post(com.api.chatbot, formData, {
cancelToken: com.sourceToken.token,
})
.then((response) => {
const botMessageIndex = this.messages.length - 1;
this.messages[botMessageIndex].ref = response.data.pmids;
this.messages[botMessageIndex].text = response.data.message;
this.sourceToken = null;
});
},
abort_chatbot() {
this.sourceToken.cancel("Request canceled");
},
closeWindow() {
this.windowCheck = false;
Expand Down

0 comments on commit 51e6c41

Please sign in to comment.