From 4057dfb69796527aac05345d7ff107cd0ebbb5c3 Mon Sep 17 00:00:00 2001
From: Vuk Manojlovic
Date: Wed, 4 Sep 2024 10:51:18 +0200
Subject: [PATCH 1/2] Hotfix: Added streaming inference to ollama-chatbot-fn

---
 tasks/ollama-chatbot-fn/main.py        |  6 +++-
 .../resources/function/function.py     | 35 ++++++++++++++++---
 tasks/ollama-chatbot-fn/task.yaml      |  7 ++++
 3 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/tasks/ollama-chatbot-fn/main.py b/tasks/ollama-chatbot-fn/main.py
index 9fdaace2..6276ff93 100644
--- a/tasks/ollama-chatbot-fn/main.py
+++ b/tasks/ollama-chatbot-fn/main.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 
 import os
+import json
 import shutil
 import logging
 
@@ -21,7 +22,7 @@ def getIndexPath(dataset: CustomDataset) -> Path:
 def main() -> None:
     taskRun = currentTaskRun()
 
-    model = Model.createModel(f"{taskRun.id}-rag-chatbot", taskRun.id, 1.0)
+    model = Model.createModel(f"{taskRun.id}-rag-chatbot", taskRun.projectId, 1.0)
 
     modelFunction = Path(".", "resources", "function")
     resDir = folder_manager.createTempFolder("resDir")
@@ -35,6 +36,9 @@ def main() -> None:
 
         copyDir(sample.path, resDir, "corpus-index")
 
+    with resDir.joinpath("metadata.json").open("w") as file:
+        json.dump({"streaming": taskRun.parameters["streaming"]}, file)
+
     model.upload(resDir)
 
     logging.info(">> [DocumentOCR] Model deployed \U0001F680\U0001F680\U0001F680")
diff --git a/tasks/ollama-chatbot-fn/resources/function/function.py b/tasks/ollama-chatbot-fn/resources/function/function.py
index 233a344d..f8068b01 100644
--- a/tasks/ollama-chatbot-fn/resources/function/function.py
+++ b/tasks/ollama-chatbot-fn/resources/function/function.py
@@ -23,6 +23,32 @@
 if rag:
     corpus, index = loadCorpusAndIndex(indexDir)
 
+with Path.cwd().parent.joinpath("metadata.json").open("r") as file:
+    metadata = json.load(file)
+    streaming = metadata.get("streaming", False)
+
+
+def saveChatHistory(messages: list[dict[str, str]], sessionPath: Path):
+    with sessionPath.open("w") as file:
+        json.dump(messages, file)
+
+
+def streamingChat(messages: list[dict[str, str]], sessionPath: Path):
+    fullResponse = ""
+    response = ollama.chat(
+        model = "llama3",
+        messages = messages,
+        stream = True
+    )
+
+    for chunk in response:
+        content = chunk["message"]["content"]
+        fullResponse += content
+        yield functions.success({"result": content})
+
+    messages.append({"role": "assistant", "content": fullResponse})
+    saveChatHistory(messages, sessionPath)
+
 
 def response(requestData: dict[str, Any]) -> dict[str, Any]:
     inputSessionId = requestData.get("session_id")
@@ -62,16 +88,17 @@ def response(requestData: dict[str, Any]) -> dict[str, Any]:
         "content": query
     })
 
-    logging.debug(">>> Running inference on LLM")
+    logging.debug(f">>> Running {'streaming' if streaming else 'batch'} inference on LLM")
+    if streaming:
+        return streamingChat(messages, sessionPath)
+
     response = ollama.chat(
         model = LLM,
         messages = messages
     )
 
     messages.append(response["message"])
-
-    with sessionPath.open("w") as file:
-        json.dump(messages, file)
+    saveChatHistory(messages, sessionPath)
 
     answer = response["message"]["content"]
     logging.debug(f">>> Returning response:\n{answer}")
diff --git a/tasks/ollama-chatbot-fn/task.yaml b/tasks/ollama-chatbot-fn/task.yaml
index 6e7945f5..117bd88e 100644
--- a/tasks/ollama-chatbot-fn/task.yaml
+++ b/tasks/ollama-chatbot-fn/task.yaml
@@ -11,3 +11,10 @@ param_groups:
     value: null
     data_type: dataset
     required: false
+- name: parameters
+  params:
+  - name: streaming
+    description: Whether the chatbot returns a streaming response (streaming node)
+    value: true
+    data_type: bool
+    required: true
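
The streamingChat function added above relies on the ollama client's streaming mode: calling ollama.chat with stream=True returns a generator of partial responses, each carrying a text fragment in chunk["message"]["content"]. A minimal standalone sketch of that pattern (the prompt below is illustrative, not taken from the patch):

    import ollama

    # stream = True makes ollama.chat return a generator of partial
    # responses instead of a single completed message.
    messages = [{"role": "user", "content": "Why is the sky blue?"}]

    fullResponse = ""
    for chunk in ollama.chat(model = "llama3", messages = messages, stream = True):
        content = chunk["message"]["content"]
        fullResponse += content                  # accumulate the complete answer
        print(content, end = "", flush = True)   # surface each fragment as it arrives

    # Record the finished assistant turn, mirroring what streamingChat
    # does before persisting the chat history.
    messages.append({"role": "assistant", "content": fullResponse})
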
From cb767d061d81a466cd0bfcbd0b4e89c03dc334ae Mon Sep 17 00:00:00 2001
From: Vuk Manojlovic
Date: Wed, 4 Sep 2024 11:18:13 +0200
Subject: [PATCH 2/2] Hotfix: Fixed linter errors

---
 tasks/ollama-chatbot-fn/resources/function/function.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tasks/ollama-chatbot-fn/resources/function/function.py b/tasks/ollama-chatbot-fn/resources/function/function.py
index f8068b01..1257b423 100644
--- a/tasks/ollama-chatbot-fn/resources/function/function.py
+++ b/tasks/ollama-chatbot-fn/resources/function/function.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Generator, Union
 from pathlib import Path
 
 import uuid
@@ -28,12 +28,12 @@
 streaming = metadata.get("streaming", False)
 
 
-def saveChatHistory(messages: list[dict[str, str]], sessionPath: Path):
+def saveChatHistory(messages: list[dict[str, str]], sessionPath: Path) -> None:
     with sessionPath.open("w") as file:
         json.dump(messages, file)
 
 
-def streamingChat(messages: list[dict[str, str]], sessionPath: Path):
+def streamingChat(messages: list[dict[str, str]], sessionPath: Path) -> Generator[dict[str, Any], Any, None]:
     fullResponse = ""
     response = ollama.chat(
         model = "llama3",
@@ -50,7 +50,7 @@ def streamingChat(messages: list[dict[str, str]], sessionPath: Path):
     saveChatHistory(messages, sessionPath)
 
 
-def response(requestData: dict[str, Any]) -> dict[str, Any]:
+def response(requestData: dict[str, Any]) -> Union[dict[str, Any], Generator[dict[str, Any], Any, None]]:
     inputSessionId = requestData.get("session_id")
     sessionId = str(uuid.uuid1()) if inputSessionId is None else inputSessionId
     sessionPath = memoryFolder / f"{sessionId}.json"
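
With the annotations from the second patch, response() returns either a single dict (batch mode) or a generator of dicts (streaming mode), so a caller has to branch on the shape it receives. A minimal sketch of such a caller, assuming the payload produced by functions.success keeps the "result" key from streamingChat and that requestData carries the prompt under a "query" key (both are assumptions, not confirmed by the patch):

    from typing import Any

    def consumeResponse(result: Any) -> None:
        # Batch mode: response() returned one dict holding the whole answer.
        if isinstance(result, dict):
            print(result)
            return

        # Streaming mode: response() returned a generator; each item is
        # assumed to wrap one text fragment under the "result" key.
        for chunk in result:
            print(chunk.get("result", ""), end = "", flush = True)
        print()

    consumeResponse(response({"query": "Why is the sky blue?"}))
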