Chatbot history and removal of chat streaming #75

Merged (3 commits) on Oct 21, 2024
55 changes: 40 additions & 15 deletions backend/src/main.py
@@ -5,7 +5,6 @@
import os.path
import signal
import sys
import time
from multiprocessing import Process

import citation_graph
@@ -24,7 +23,7 @@
from werkzeug.middleware.proxy_fix import ProxyFix

app = Flask(__name__)

history = []
# ====================== Index page ======================

_SCRIPT_DIR = os.path.dirname(__file__)
@@ -139,6 +138,33 @@ def abstract_summary():
return Response(json.dumps(response), mimetype="application/json")


def make_prompt(message, proteins, funct_terms, abstract):
"""
Create a prompt for the chatbot.
Args:
message: Input message from user.
proteins: User selected proteins to be included in the prompt.
funct_terms: User selected functional terms to be included in the prompt.
abstract: User selected abstracts to be included in the prompt.
Returns:
prompt: The prompt to be used for response generation.
"""
pro = "use the following proteins:" if len(proteins) > 0 else ""
funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
abstract_is = (
"use the following abstracts and state PMID if you use them for information:"
if len(abstract) > 0
else ""
)
proteins = proteins if len(proteins) > 0 else ""
funct_terms = funct_terms if len(funct_terms) > 0 else ""
abstract = abstract if len(abstract) > 0 else ""
prompt = (
f"{message} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
)
return prompt


@app.route("/api/subgraph/chatbot", methods=["POST"])
def chatbot_response():
message, background = (request.form.get("message"), request.form.get("background"))
@@ -169,24 +195,23 @@ def chatbot_response():
embedded_query, pmids_embeddings, 6
)
stopwatch.round("Vector search")
query = ""
abstracts = ""
for i in top_n_similiar:
query += f"PMID: {i} Abstract: {pmid_abstract[i]} "
answer = create_summary_RAG(
str(message),
abstracts += f"PMID: {i} Abstract: {pmid_abstract[i]} "
message = make_prompt(
message=message,
proteins=protein_list,
funct_terms=funct_terms_list,
abstract=query,
abstract=abstracts,
)

def generate():
# Yield each message from 'answer'
for i in answer:
yield json.dumps({"message": i["response"], "pmids": top_n_similiar})
time.sleep(0.1)

history.append({"role": "user", "content": message})
answer = create_summary_RAG(
history=history,
)
history.append(answer)
response = json.dumps({"message": answer["content"], "pmids": top_n_similiar})
stopwatch.round("Generating answer")
return Response(generate(), mimetype="application/json")
return Response(response, mimetype="application/json")


# ====================== Subgraph API ======================
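Note: a hypothetical call to the new make_prompt helper, with made-up inputs, to show what the assembled prompt looks like before it becomes the next user turn appended to the module-level history list:

# Illustrative only; make_prompt is the helper added in main.py above.
prompt = make_prompt(
    message="How are these proteins involved in DNA repair?",
    proteins=["TP53", "BRCA1"],
    funct_terms=["DNA repair"],
    abstract="PMID: 12345678 Abstract: ...",
)
# The lists are interpolated into an f-string, so the result reads roughly:
# "How are these proteins involved in DNA repair? use the following proteins:
#  ['TP53', 'BRCA1'] use the following functional terms: ['DNA repair']
#  use the following abstracts and state PMID if you use them for information:
#  PMID: 12345678 Abstract: ..."
history.append({"role": "user", "content": prompt})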
22 changes: 3 additions & 19 deletions backend/src/summarization/model.py
@@ -70,22 +70,6 @@ def get_response(prompt):
return [summary]


def create_summary_RAG(query, proteins, funct_terms, abstract):
pro = "use the following proteins:" if len(proteins) > 0 else ""
funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
abstract_is = (
"use the following abstracts and state PMID if you use them for information:"
if len(abstract) > 0
else ""
)
proteins = proteins if len(proteins) > 0 else ""
funct_terms = funct_terms if len(funct_terms) > 0 else ""
abstract = abstract if len(abstract) > 0 else ""

def get_response(prompt):
response = ollama.generate(model="llama3.1", prompt=prompt, stream=True)
return response

prompt = f"{query} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
summary = get_response(prompt)
return summary
def create_summary_RAG(history):
response = ollama.chat(model="llama3.1", messages=history)
return response["message"]
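Note: a minimal sketch of the new history-based flow, assuming the ollama Python client and a locally pulled llama3.1 model as used above. The chat call is not streamed, so the full answer arrives at once; appending the returned assistant message to the history is what lets follow-up questions keep their context:

# Sketch only, not the application code; the questions are made up.
import ollama

history = []

def create_summary_RAG(history):
    # Single, non-streaming chat completion over the whole conversation.
    response = ollama.chat(model="llama3.1", messages=history)
    return response["message"]  # {"role": "assistant", "content": "..."}

history.append({"role": "user", "content": "Summarise the role of TP53."})
answer = create_summary_RAG(history=history)
history.append(answer)  # keep the assistant turn for the next request

history.append({"role": "user", "content": "How does it relate to MDM2?"})
follow_up = create_summary_RAG(history=history)
print(follow_up["content"])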
96 changes: 27 additions & 69 deletions frontend/src/components/PersistentWindow.vue
@@ -28,7 +28,7 @@
{{ element.id }}
</span>
</div>
<div v-if="msg.ref">
<div v-if="msg.ref !== null">
<span class="small-tag blue" @click="searchRef(msg.ref)">
reference
</span>
@@ -74,7 +74,7 @@ export default {
api: {
chatbot: "api/subgraph/chatbot",
},
controller: null,
sourceToken: null,
};
},
computed: {
@@ -227,80 +227,38 @@ export default {
inputDiv.innerText = "";
}
},
async streamChatbotResponse(formData) {
let refData = null;
if (this.controller) {
this.controller.abort();
}

// Create a new AbortController instance
this.controller = new AbortController();
const signal = this.controller.signal; // Get the signal

const botMessage = {
sender: "Bot",
text: "Waiting for response...", // Initially empty, will be updated progressively
data: [...this.tags], // Add contextual data if needed
ref: null, // This will hold the pmids when received
};
this.messages.push(botMessage);

const response = await fetch(this.api.chatbot, {
method: "POST",
body: formData,
signal: signal,
});

if (!response.body) {
throw new Error("No response body");
}

const reader = response.body.getReader();
const decoder = new TextDecoder("utf-8");
let done = false;
let fullText = "";

// Index of the newly added Bot message
const botMessageIndex = this.messages.length - 1;

while (!done) {
const { value, done: readerDone } = await reader.read();
done = readerDone;
if (done) break;

// Decode the streamed data
const chunk = decoder.decode(value || new Uint8Array(), {
stream: !done,
});
// Parse the chunk as JSON to extract "messages" and "pmids"
let parsedChunk = JSON.parse(chunk);
// Check if it's a message part or pmids
if (parsedChunk.message) {
// Append message chunks to the fullText
fullText += parsedChunk.message;

// Update the bot message progressively
this.updateBotMessage(fullText, botMessageIndex);
}

if (parsedChunk.pmids) {
// If pmids are received, store them for later
refData = parsedChunk.pmids;
this.messages[botMessageIndex].ref = refData;
}
}
},
updateBotMessage(text, index) {
// Ensure we're updating the correct (newest) bot message by index
this.messages[index].text = text;
},
getAnswer(message) {
let com = this;
let formData = new FormData();
formData.append("message", message.text);
formData.append("background", JSON.stringify(message.data));

com.streamChatbotResponse(formData);
if (this.sourceToken) {
this.abort_chatbot();
}

this.messages.push({
sender: "Bot",
text: "Waiting for response...",
data: [...this.tags],
ref: null,
});
// POST request to the chatbot endpoint
com.sourceToken = this.axios.CancelToken.source();
com.axios
.post(com.api.chatbot, formData, {
cancelToken: com.sourceToken.token,
})
.then((response) => {
const botMessageIndex = this.messages.length - 1;
this.messages[botMessageIndex].ref = response.data.pmids;
this.messages[botMessageIndex].text = response.data.message;
this.sourceToken = null;
});
},
abort_chatbot() {
this.sourceToken.cancel("Request canceled");
},
closeWindow() {
this.windowCheck = false;
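Note: with streaming removed, the frontend no longer reads chunks from a fetch body; it sends one axios POST, cancels any previous in-flight request via a cancel token, and fills in the bot message when the complete JSON payload arrives. A rough Python equivalent of that client call, with an illustrative URL and payload (Flask's default dev-server port is assumed):

# Rough client-side sketch using the requests package; not part of the repo.
import json

import requests

resp = requests.post(
    "http://localhost:5000/api/subgraph/chatbot",
    data={
        "message": "Summarise these proteins.",
        "background": json.dumps([]),  # the tags/context the UI attaches
    },
    timeout=120,
)
payload = resp.json()
print(payload["message"])  # the complete answer, delivered in one piece
print(payload["pmids"])    # PMIDs of the abstracts used for the answer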