From 1f3fe40f189af0f160f7200a1cd956c5867fdfd3 Mon Sep 17 00:00:00 2001 From: Denis Diachkov Date: Mon, 9 Sep 2024 23:04:58 +0200 Subject: [PATCH] engines refactoring --- src/AGISwarm/llm_instruct_ms/app.py | 7 +-- src/AGISwarm/llm_instruct_ms/gui/scripts.js | 19 +++--- .../llm_instruct_ms/llm_engines/__init__.py | 2 +- .../llm_engines/{utils.py => engine.py} | 61 +++++++++++++++---- .../llm_instruct_ms/llm_engines/hf_engine.py | 34 +---------- .../llm_engines/llama_cpp_engine.py | 34 +---------- .../llm_engines/vllm_engine.py | 40 +----------- src/AGISwarm/llm_instruct_ms/typing.py | 1 + 8 files changed, 71 insertions(+), 127 deletions(-) rename src/AGISwarm/llm_instruct_ms/llm_engines/{utils.py => engine.py} (54%) diff --git a/src/AGISwarm/llm_instruct_ms/app.py b/src/AGISwarm/llm_instruct_ms/app.py index 1590c7f..a8a06d1 100644 --- a/src/AGISwarm/llm_instruct_ms/app.py +++ b/src/AGISwarm/llm_instruct_ms/app.py @@ -14,8 +14,7 @@ from omegaconf import OmegaConf from pydantic import BaseModel -from .llm_engines import EngineProtocol -from .llm_engines.utils import ConcurrentEngineProtocol +from .llm_engines import ConcurrentEngine, Engine from .typing import ( ENGINE_MAP, ENGINE_SAMPLING_PARAMS_MAP, @@ -32,7 +31,7 @@ def __init__(self, config: LLMInstructConfig): self.app = FastAPI() if config.engine_config is None: config.engine_config = cast(None, OmegaConf.create()) - self.llm_pipeline: EngineProtocol[Any] = ENGINE_MAP[config.engine]( # type: ignore + self.llm_pipeline: Engine[Any] = ENGINE_MAP[config.engine]( # type: ignore hf_model_name=config.hf_model_name, tokenizer_name=config.tokenizer_name, **cast(dict, OmegaConf.to_container(config.engine_config)), @@ -92,7 +91,7 @@ async def generate(self, websocket: WebSocket): # type: ignore queued_task = self.queue_manager.queued_generator( self.llm_pipeline.__call__, pass_task_id=isinstance( - self.llm_pipeline, ConcurrentEngineProtocol # type: ignore + self.llm_pipeline, ConcurrentEngine # type: ignore ), ) # task_id and interrupt_event are created by the queued_generator diff --git a/src/AGISwarm/llm_instruct_ms/gui/scripts.js b/src/AGISwarm/llm_instruct_ms/gui/scripts.js index 2ae2f37..6fac3b0 100644 --- a/src/AGISwarm/llm_instruct_ms/gui/scripts.js +++ b/src/AGISwarm/llm_instruct_ms/gui/scripts.js @@ -13,7 +13,6 @@ function resetForm() { document.getElementById('presence_penalty').value = DEFAULT_PRESENCE_PENALTY; document.getElementById('system_prompt').value = DEFAULT_SYSTEM_PROMPT; } -let currentMessage = ''; let currentRequestID = ''; @@ -33,7 +32,7 @@ function enableAbortButton() { document.getElementById('send-btn').disabled = false; } -function updateBotMessage(message) { +function updateBotMessage(message, replace = false) { const chatOutput = document.getElementById('chat-output'); // Check if the bot message div already exists let botMessageContainer = chatOutput.firstElementChild; @@ -47,7 +46,12 @@ function updateBotMessage(message) { botMessageContainer.appendChild(botMessage); chatOutput.insertBefore(botMessageContainer, chatOutput.firstChild); } - botMessage.textContent += message; + if (replace) { + botMessage.textContent = message; + } + else { + botMessage.textContent += message; + } botMessage.style.color = 'black'; const isAtBottom = chatOutput.scrollHeight - chatOutput.clientHeight <= chatOutput.scrollTop + 1; if (isAtBottom) { @@ -62,23 +66,20 @@ ws.onmessage = function (event) { response_dict = JSON.parse(event.data); console.log(response_dict); currentRequestID = JSON.parse(event.data)["task_id"]; - + switch 
(response_dict["status"]) { case "starting": - currentMessage = ''; disableGenerateButton(); return; case "finished": - currentMessage = ''; enableGenerateButton(); return; case "waiting": queue_pos = response_dict["queue_pos"]; - botMessage.textContent = "Waiting in queue. Position: " + queue_pos; - botMessage.style.color = 'blue'; + updateBotMessage("
" + "You are in position " + queue_pos + " in the queue", replace = true); enableAbortButton(); return; - case "abort": + case "aborted": updateBotMessage("
" + "Generation aborted"); enableAbortButton(); return; diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py b/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py index bab79d8..ca96b5d 100644 --- a/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py +++ b/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py @@ -4,7 +4,7 @@ from typing import Protocol, runtime_checkable +from .engine import ConcurrentEngine, Engine from .hf_engine import HFEngine, HFSamplingParams from .llama_cpp_engine import LlamaCppEngine, LlamaCppSamplingParams -from .utils import EngineProtocol from .vllm_engine import VLLMEngine, VLLMSamplingParams diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/utils.py b/src/AGISwarm/llm_instruct_ms/llm_engines/engine.py similarity index 54% rename from src/AGISwarm/llm_instruct_ms/llm_engines/utils.py rename to src/AGISwarm/llm_instruct_ms/llm_engines/engine.py index e10efe9..e324c90 100644 --- a/src/AGISwarm/llm_instruct_ms/llm_engines/utils.py +++ b/src/AGISwarm/llm_instruct_ms/llm_engines/engine.py @@ -1,7 +1,7 @@ """Utility functions for LLM engines""" from abc import abstractmethod -from typing import Dict, Generic, List, Protocol, TypeVar, cast, runtime_checkable +from typing import Dict, Generic, List, TypeVar, cast from pydantic import BaseModel @@ -19,12 +19,13 @@ class SamplingParams(BaseModel): ) -@runtime_checkable # pylint: disable=too-few-public-methods -class EngineProtocol(Protocol, Generic[_SamplingParams_contra]): +class Engine(Generic[_SamplingParams_contra]): """Engine protocol""" - @abstractmethod + conversations: Dict[str, List[Dict[str, str]]] + + # pylint: disable=too-many-arguments async def __call__( self, conversation_id: str, @@ -33,8 +34,26 @@ async def __call__( reply_prefix: str, sampling_params: _SamplingParams_contra, ): - """Generate text from prompt""" - yield str() + if system_prompt != "": + self.conversations[conversation_id].append( + { + "role": "system", + "content": system_prompt, + } + ) + self.conversations[conversation_id].append({"role": "user", "content": prompt}) + reply: str = "" + async for response in self.generate( + self.conversations[conversation_id], + reply_prefix, + sampling_params, + ): + reply += response + yield response + self.conversations[conversation_id].append( + {"role": "assistant", "content": reply} + ) + yield "" @abstractmethod async def generate( @@ -47,12 +66,12 @@ async def generate( yield str() -@runtime_checkable # pylint: disable=too-few-public-methods -class ConcurrentEngineProtocol(Protocol, Generic[_SamplingParams_contra]): +class ConcurrentEngine(Generic[_SamplingParams_contra]): """Concurrent engine protocol""" - @abstractmethod + conversations: Dict[str, List[Dict[str, str]]] + # pylint: disable=too-many-arguments async def __call__( self, @@ -63,16 +82,34 @@ async def __call__( sampling_params: _SamplingParams_contra, task_id: str, ): - """Generate text from prompt""" - yield str() + if conversation_id not in self.conversations: + self.conversations[conversation_id] = [] + if system_prompt != "": + self.conversations[conversation_id].append( + { + "role": "system", + "content": system_prompt, + } + ) + self.conversations[conversation_id].append({"role": "user", "content": prompt}) + reply: str = "" + async for response in self.generate( + self.conversations[conversation_id], reply_prefix, sampling_params, task_id + ): + reply += response + yield response + self.conversations[conversation_id].append( + {"role": "assistant", "content": reply} + ) + yield "" @abstractmethod async 
def generate( self, - task_id: str, messages: List[Dict[str, str]], reply_prefix: str, sampling_params: _SamplingParams_contra, + task_id: str, ): """Generate text from prompt""" yield str() diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py b/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py index 332c2c9..58cb08d 100644 --- a/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py +++ b/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py @@ -7,7 +7,7 @@ import transformers # type: ignore from pydantic import Field -from .utils import EngineProtocol, SamplingParams, prepare_prompt +from .engine import Engine, SamplingParams, prepare_prompt SUPPORTED_MODELS = [ "meta-llama/Meta-Llama-3-8B-Instruct", @@ -48,7 +48,7 @@ class HFSamplingParams(SamplingParams): # pylint: disable=too-few-public-methods -class HFEngine(EngineProtocol[HFSamplingParams]): # pylint: disable=invalid-name +class HFEngine(Engine[HFSamplingParams]): # pylint: disable=invalid-name """LLM Instruct Model Inference""" def __init__( @@ -100,33 +100,3 @@ async def generate( yield reply_prefix for new_text in streamer: yield cast(str, new_text) - - # pylint: disable=too-many-arguments - async def __call__( - self, - conversation_id: str, - prompt: str, - system_prompt: str, - reply_prefix: str, - sampling_params: HFSamplingParams, - ): - if system_prompt != "": - self.conversations[conversation_id].append( - { - "role": "system", - "content": system_prompt, - } - ) - self.conversations[conversation_id].append({"role": "user", "content": prompt}) - reply: str = "" - async for response in self.generate( - self.conversations[conversation_id], - reply_prefix, - sampling_params, - ): - reply += response - yield response - self.conversations[conversation_id].append( - {"role": "assistant", "content": reply} - ) - yield "" diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py b/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py index c45849f..24ed4c8 100644 --- a/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py +++ b/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py @@ -6,7 +6,7 @@ from pydantic import Field from transformers import AutoTokenizer # type: ignore -from .utils import EngineProtocol, SamplingParams, prepare_prompt +from .engine import Engine, SamplingParams, prepare_prompt class LlamaCppSamplingParams(SamplingParams): @@ -17,7 +17,7 @@ class LlamaCppSamplingParams(SamplingParams): presence_penalty: float = Field(default=0.0, description="Presence penalty") -class LlamaCppEngine(EngineProtocol[LlamaCppSamplingParams]): +class LlamaCppEngine(Engine[LlamaCppSamplingParams]): """LLM Instruct Model Inference""" def __init__( # pylint: disable=too-many-arguments @@ -63,33 +63,3 @@ async def generate( ): output = cast(CreateCompletionStreamResponse, output) yield output["choices"][0]["text"] - - # pylint: disable=too-many-arguments - async def __call__( - self, - conversation_id: str, - prompt: str, - system_prompt: str, - reply_prefix: str, - sampling_params: LlamaCppSamplingParams, - ): - if system_prompt != "": - self.conversations[conversation_id].append( - { - "role": "system", - "content": system_prompt, - } - ) - self.conversations[conversation_id].append({"role": "user", "content": prompt}) - reply: str = "" - async for response in self.generate( - self.conversations[conversation_id], - reply_prefix, - sampling_params, - ): - reply += response - yield response - self.conversations[conversation_id].append( - {"role": "assistant", "content": reply} - 
) - yield "" diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py b/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py index 5592af7..5df0e1e 100644 --- a/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py +++ b/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py @@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download from pydantic import Field -from .utils import ConcurrentEngineProtocol, SamplingParams, prepare_prompt +from .engine import ConcurrentEngine, SamplingParams, prepare_prompt class VLLMSamplingParams(SamplingParams): @@ -19,7 +19,7 @@ class VLLMSamplingParams(SamplingParams): presence_penalty: float = Field(default=0.0, description="Presence penalty") -class VLLMEngine(ConcurrentEngineProtocol[VLLMSamplingParams]): +class VLLMEngine(ConcurrentEngine[VLLMSamplingParams]): """LLM Instruct Model Inference using VLLM""" def __init__( @@ -64,10 +64,10 @@ def get_sampling_params( async def generate( self, - task_id: str, messages: list[dict], reply_prefix: str | None, sampling_params: VLLMSamplingParams, + task_id: str, ): """Generate text from prompt""" prompt = prepare_prompt(self.tokenizer, messages, reply_prefix) @@ -82,37 +82,3 @@ async def generate( current_len = len(output.outputs[0].text) if output.finished: break - - # pylint: disable=too-many-arguments - async def __call__( - self, - conversation_id: str, - prompt: str, - system_prompt: str, - reply_prefix: str, - sampling_params: VLLMSamplingParams, - task_id: str, - ): - if conversation_id not in self.conversations: - self.conversations[conversation_id] = [] - if system_prompt != "": - self.conversations[conversation_id].append( - { - "role": "system", - "content": system_prompt, - } - ) - self.conversations[conversation_id].append({"role": "user", "content": prompt}) - reply: str = "" - async for response in self.generate( - task_id, - self.conversations[conversation_id], - reply_prefix, - sampling_params, - ): - reply += response - yield response - self.conversations[conversation_id].append( - {"role": "assistant", "content": reply} - ) - yield "" diff --git a/src/AGISwarm/llm_instruct_ms/typing.py b/src/AGISwarm/llm_instruct_ms/typing.py index 8627b6f..fca48ec 100644 --- a/src/AGISwarm/llm_instruct_ms/typing.py +++ b/src/AGISwarm/llm_instruct_ms/typing.py @@ -38,6 +38,7 @@ class VLLMConfig(ModelConfig): filename: str | None = None + class HFConfig(ModelConfig): """HF settings"""
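
For illustration only (this note and the sketch below are not part of the patch): the net effect of the refactoring is that the former EngineProtocol/ConcurrentEngineProtocol protocols become concrete Engine/ConcurrentEngine base classes whose __call__ owns the conversation bookkeeping, so each backend (HFEngine, LlamaCppEngine, VLLMEngine) now only overrides the async generate generator. The EchoEngine below is a hypothetical minimal subclass invented to show the new contract; it assumes the subclass is responsible for initializing conversations (e.g. as a defaultdict), since Engine.__call__ appends to self.conversations[conversation_id] without creating missing entries.

    # Hypothetical example, not part of the patch: a toy subclass of the
    # refactored Engine base class introduced in llm_engines/engine.py.
    from collections import defaultdict
    from typing import Dict, List

    from AGISwarm.llm_instruct_ms.llm_engines.engine import Engine, SamplingParams


    class EchoEngine(Engine[SamplingParams]):
        """Toy engine that streams the last user message back, word by word."""

        def __init__(self) -> None:
            # Engine.__call__ reads self.conversations but does not create
            # missing conversation ids, so a defaultdict keeps it KeyError-free.
            self.conversations: Dict[str, List[Dict[str, str]]] = defaultdict(list)

        async def generate(
            self,
            messages: List[Dict[str, str]],
            reply_prefix: str,
            sampling_params: SamplingParams,
        ):
            # Only the streaming itself is engine-specific; history management
            # (system prompt, user turn, assistant reply) is handled by __call__.
            if reply_prefix:
                yield reply_prefix
            for word in messages[-1]["content"].split():
                yield word + " "

With this split, the isinstance(self.llm_pipeline, ConcurrentEngine) check in app.py decides whether the queue manager threads a task_id through to the pipeline, which is why VLLMEngine.generate now takes task_id as its last parameter while the non-concurrent engines keep the three-argument signature.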