NN-597 Implement chatbot memory
Maluuck committed Oct 21, 2024
1 parent 8e329eb · commit c7dc67c
Showing 2 changed files with 42 additions and 26 deletions.
46 changes: 39 additions & 7 deletions backend/src/main.py
@@ -23,7 +23,7 @@
 from werkzeug.middleware.proxy_fix import ProxyFix
 
 app = Flask(__name__)
-
+history = []
 # ====================== Index page ======================
 
 _SCRIPT_DIR = os.path.dirname(__file__)
@@ -138,6 +138,33 @@ def abstract_summary():
     return Response(json.dumps(response), mimetype="application/json")
 
 
+def make_prompt(message, proteins, funct_terms, abstract):
+    """
+    Create a prompt for the chatbot.
+    Args:
+        message: Input message from the user.
+        proteins: User-selected proteins to be included in the prompt.
+        funct_terms: User-selected functional terms to be included in the prompt.
+        abstract: User-selected abstracts to be included in the prompt.
+    Returns:
+        prompt: The prompt to be used for response generation.
+    """
+    pro = "use the following proteins:" if len(proteins) > 0 else ""
+    funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
+    abstract_is = (
+        "use the following abstracts and state PMID if you use them for information:"
+        if len(abstract) > 0
+        else ""
+    )
+    proteins = proteins if len(proteins) > 0 else ""
+    funct_terms = funct_terms if len(funct_terms) > 0 else ""
+    abstract = abstract if len(abstract) > 0 else ""
+    prompt = (
+        f"{message} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
+    )
+    return prompt
+
+
 @app.route("/api/subgraph/chatbot", methods=["POST"])
 def chatbot_response():
     message, background = (request.form.get("message"), request.form.get("background"))
@@ -168,16 +195,21 @@ def chatbot_response():
         embedded_query, pmids_embeddings, 6
     )
     stopwatch.round("Vector search")
-    query = ""
+    abstracts = ""
     for i in top_n_similiar:
-        query += f"PMID: {i} Abstract: {pmid_abstract[i]} "
-    answer = create_summary_RAG(
-        str(message),
+        abstracts += f"PMID: {i} Abstract: {pmid_abstract[i]} "
+    message = make_prompt(
+        message=message,
         proteins=protein_list,
         funct_terms=funct_terms_list,
-        abstract=query,
+        abstract=abstracts,
+    )
+    history.append({"role": "user", "content": message})
+    answer = create_summary_RAG(
+        history=history,
     )
-    response = json.dumps({"message": answer, "pmids": top_n_similiar})
+    history.append(answer)
+    response = json.dumps({"message": answer["content"], "pmids": top_n_similiar})
     stopwatch.round("Generating answer")
     return Response(response, mimetype="application/json")
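With this change, prompt assembly moves into main.py and each turn is recorded in the module-level history list: make_prompt builds one prompt string from the user message plus whichever protein, functional-term, and abstract sections are non-empty. A minimal sketch of what make_prompt returns (the input values here are hypothetical, and lists are interpolated with their Python repr):

# Hypothetical inputs for illustration only.
prompt = make_prompt(
    message="How are these proteins related?",
    proteins=["TP53", "BRCA1"],
    funct_terms=["DNA repair"],
    abstract="PMID: 12345 Abstract: ...",
)
# prompt == "How are these proteins related? use the following proteins:
# ['TP53', 'BRCA1'] use the following functional terms: ['DNA repair']
# use the following abstracts and state PMID if you use them for information:
# PMID: 12345 Abstract: ..."  (wrapped here for readability)

Note that history is a module-level list, so every request appends to the same shared conversation rather than a per-session one.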
22 changes: 3 additions & 19 deletions backend/src/summarization/model.py
@@ -70,22 +70,6 @@ def get_response(prompt):
     return [summary]
 
 
-def create_summary_RAG(query, proteins, funct_terms, abstract):
-    pro = "use the following proteins:" if len(proteins) > 0 else ""
-    funct = "use the following functional terms:" if len(funct_terms) > 0 else ""
-    abstract_is = (
-        "use the following abstracts and state PMID if you use them for information:"
-        if len(abstract) > 0
-        else ""
-    )
-    proteins = proteins if len(proteins) > 0 else ""
-    funct_terms = funct_terms if len(funct_terms) > 0 else ""
-    abstract = abstract if len(abstract) > 0 else ""
-
-    def get_response(prompt):
-        response = ollama.generate(model="llama3.1", prompt=prompt, stream=False)
-        return response["response"]
-
-    prompt = f"{query} {pro} {proteins} {funct} {funct_terms} {abstract_is} {abstract}"
-    summary = get_response(prompt)
-    return summary
+def create_summary_RAG(history):
+    response = ollama.chat(model="llama3.1", messages=history)
+    return response["message"]

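create_summary_RAG now delegates conversational memory to ollama.chat: history is a list of role/content dicts, and the returned "message" dict can be appended back onto history for the next turn, which is what chatbot_response does above. A rough usage sketch, assuming the ollama Python client, a locally pulled llama3.1 model, and an import path inferred from the file layout:

# Sketch only; assumes backend/src is on the import path.
from summarization.model import create_summary_RAG

history = [{"role": "user", "content": "Summarize the role of TP53."}]
answer = create_summary_RAG(history=history)  # {"role": "assistant", "content": "..."}
history.append(answer)  # keep the reply so the next turn has context
print(answer["content"])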