Commit f97f801

Merge pull request #190 from VukManojlovic/hotfix-chatbot-streaming
Hotfix: Added streaming inference to ollama-chatbot-fn
igorperic17 authored Sep 4, 2024
2 parents 4c43e95 + cb767d0 commit f97f801
Showing 3 changed files with 45 additions and 7 deletions.
6 changes: 5 additions & 1 deletion tasks/ollama-chatbot-fn/main.py
@@ -1,6 +1,7 @@
 from pathlib import Path

 import os
+import json
 import shutil
 import logging
@@ -21,7 +22,7 @@ def getIndexPath(dataset: CustomDataset) -> Path:
 def main() -> None:
     taskRun = currentTaskRun()

-    model = Model.createModel(f"{taskRun.id}-rag-chatbot", taskRun.id, 1.0)
+    model = Model.createModel(f"{taskRun.id}-rag-chatbot", taskRun.projectId, 1.0)

     modelFunction = Path(".", "resources", "function")
     resDir = folder_manager.createTempFolder("resDir")
@@ -35,6 +36,9 @@ def main() -> None:

     copyDir(sample.path, resDir, "corpus-index")

+    with resDir.joinpath("metadata.json").open("w") as file:
+        json.dump({"streaming": taskRun.parameters["streaming"]}, file)
+
     model.upload(resDir)

     logging.info(">> [DocumentOCR] Model deployed \U0001F680\U0001F680\U0001F680")
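The new metadata.json file is the handshake between the deploy step and the serving function: main.py writes the streaming flag next to the model files, and function.py reads it back when the function loads. A minimal standalone sketch of that round trip, assuming only the standard library (the resDir path here is a stand-in for the real temp folder):

import json
from pathlib import Path

resDir = Path("resDir")  # stand-in for the temp folder that becomes the model payload
resDir.mkdir(exist_ok = True)

# Deploy side (main.py): persist the task parameter alongside the model files
with resDir.joinpath("metadata.json").open("w") as file:
    json.dump({"streaming": True}, file)

# Serving side (function.py): read it back, defaulting to batch mode if the key is absent
with resDir.joinpath("metadata.json").open("r") as file:
    streaming = json.load(file).get("streaming", False)

print(streaming)  # True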
39 changes: 33 additions & 6 deletions tasks/ollama-chatbot-fn/resources/function/function.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Generator, Union
 from pathlib import Path

 import uuid
@@ -23,8 +23,34 @@
 if rag:
     corpus, index = loadCorpusAndIndex(indexDir)

+with Path.cwd().parent.joinpath("metadata.json").open("r") as file:
+    metadata = json.load(file)
+
+streaming = metadata.get("streaming", False)
+
-def response(requestData: dict[str, Any]) -> dict[str, Any]:
+
+def saveChatHistory(messages: list[dict[str, str]], sessionPath: Path) -> None:
+    with sessionPath.open("w") as file:
+        json.dump(messages, file)
+
+
+def streamingChat(messages: list[dict[str, str]], sessionPath: Path) -> Generator[dict[str, Any], Any, None]:
+    fullResponse = ""
+    response = ollama.chat(
+        model = "llama3",
+        messages = messages,
+        stream = True
+    )
+
+    for chunk in response:
+        content = chunk["message"]["content"]
+        fullResponse += content
+        yield functions.success({"result": content})
+
+    messages.append({"role": "assistant", "content": fullResponse})
+    saveChatHistory(messages, sessionPath)
+
+
+def response(requestData: dict[str, Any]) -> Union[dict[str, Any], Generator[dict[str, Any], Any, None]]:
     inputSessionId = requestData.get("session_id")
     sessionId = str(uuid.uuid1()) if inputSessionId is None else inputSessionId
     sessionPath = memoryFolder / f"{sessionId}.json"
@@ -62,16 +88,17 @@ def response(requestData: dict[str, Any]) -> dict[str, Any]:
         "content": query
     })

-    logging.debug(">>> Running inference on LLM")
+    logging.debug(f">>> Running {'streaming' if streaming else 'batch'} inference on LLM")
+    if streaming:
+        return streamingChat(messages, sessionPath)

     response = ollama.chat(
         model = LLM,
         messages = messages
     )

     messages.append(response["message"])

-    with sessionPath.open("w") as file:
-        json.dump(messages, file)
+    saveChatHistory(messages, sessionPath)

     answer = response["message"]["content"]
     logging.debug(f">>> Returning response:\n{answer}")
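The heart of the hotfix is the stream = True call to ollama.chat, which returns an iterator of chunks instead of a single completed reply. A minimal sketch of the same pattern outside the Coretex wrapper, assuming a local Ollama server with the llama3 model already pulled (the prompt is illustrative):

import ollama

messages = [{"role": "user", "content": "Summarize RAG in one sentence."}]

fullResponse = ""
for chunk in ollama.chat(model = "llama3", messages = messages, stream = True):
    content = chunk["message"]["content"]
    fullResponse += content                 # accumulate so the history stores one complete message
    print(content, end = "", flush = True)  # emit each chunk as soon as it arrives

# Mirror streamingChat(): once the stream ends, append the assembled reply
# to the history so follow-up turns see the full conversation.
messages.append({"role": "assistant", "content": fullResponse})

Note that response() now returns either a dict or a generator, so any caller outside a streaming-aware runtime has to check which one it received before consuming the result.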
7 changes: 7 additions & 0 deletions tasks/ollama-chatbot-fn/task.yaml
@@ -11,3 +11,10 @@ param_groups:
         value: null
         data_type: dataset
         required: false
+  - name: parameters
+    params:
+      - name: streaming
+        description: If the chatbot will return a streaming response (streaming node)
+        value: true
+        data_type: bool
+        required: true

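The new streaming parameter declared here is what main.py reads via taskRun.parameters["streaming"] and bakes into metadata.json at deploy time. A short sketch of that lookup, assuming the coretex SDK import path used by these tasks (currentTaskRun appears in the diff above; the log line is illustrative):

from coretex import currentTaskRun

taskRun = currentTaskRun()

# task.yaml declares streaming as a required bool defaulting to true,
# so this lookup is always populated at run time
streaming = taskRun.parameters["streaming"]
print(f"Deploying chatbot with streaming {'enabled' if streaming else 'disabled'}")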