engines refactoring
DenisDiachkov committed Sep 9, 2024
1 parent ff8f48e commit 1f3fe40
Showing 8 changed files with 71 additions and 127 deletions.
7 changes: 3 additions & 4 deletions src/AGISwarm/llm_instruct_ms/app.py
@@ -14,8 +14,7 @@
from omegaconf import OmegaConf
from pydantic import BaseModel

from .llm_engines import EngineProtocol
from .llm_engines.utils import ConcurrentEngineProtocol
from .llm_engines import ConcurrentEngine, Engine
from .typing import (
ENGINE_MAP,
ENGINE_SAMPLING_PARAMS_MAP,
@@ -32,7 +31,7 @@ def __init__(self, config: LLMInstructConfig):
self.app = FastAPI()
if config.engine_config is None:
config.engine_config = cast(None, OmegaConf.create())
self.llm_pipeline: EngineProtocol[Any] = ENGINE_MAP[config.engine]( # type: ignore
self.llm_pipeline: Engine[Any] = ENGINE_MAP[config.engine]( # type: ignore
hf_model_name=config.hf_model_name,
tokenizer_name=config.tokenizer_name,
**cast(dict, OmegaConf.to_container(config.engine_config)),
@@ -92,7 +91,7 @@ async def generate(self, websocket: WebSocket): # type: ignore
queued_task = self.queue_manager.queued_generator(
self.llm_pipeline.__call__,
pass_task_id=isinstance(
self.llm_pipeline, ConcurrentEngineProtocol # type: ignore
self.llm_pipeline, ConcurrentEngine # type: ignore
),
)
# task_id and interrupt_event are created by the queued_generator
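Since ConcurrentEngine and Engine are now concrete base classes rather than runtime_checkable Protocols, the pass_task_id decision above is an ordinary inheritance check. A minimal sketch of that dispatch, not code from this repo; the absolute import path and the needs_task_id helper are assumptions for illustration:

from typing import Any

# Import path inferred from this commit's relative imports; adjust to the actual package layout.
from AGISwarm.llm_instruct_ms.llm_engines import ConcurrentEngine, Engine


def needs_task_id(pipeline: Engine[Any] | ConcurrentEngine[Any]) -> bool:
    """Hypothetical helper mirroring the pass_task_id decision in app.py.

    VLLMEngine derives from ConcurrentEngine and needs a task_id per request;
    HFEngine and LlamaCppEngine derive from Engine and do not.
    """
    # Plain inheritance check; no runtime_checkable Protocol involved anymore.
    return isinstance(pipeline, ConcurrentEngine)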
19 changes: 10 additions & 9 deletions src/AGISwarm/llm_instruct_ms/gui/scripts.js
@@ -13,7 +13,6 @@ function resetForm() {
document.getElementById('presence_penalty').value = DEFAULT_PRESENCE_PENALTY;
document.getElementById('system_prompt').value = DEFAULT_SYSTEM_PROMPT;
}
let currentMessage = '';
let currentRequestID = '';


@@ -33,7 +32,7 @@ function enableAbortButton() {
document.getElementById('send-btn').disabled = false;
}

function updateBotMessage(message) {
function updateBotMessage(message, replace = false) {
const chatOutput = document.getElementById('chat-output');
// Check if the bot message div already exists
let botMessageContainer = chatOutput.firstElementChild;
@@ -47,7 +46,12 @@ function updateBotMessage(message) {
botMessageContainer.appendChild(botMessage);
chatOutput.insertBefore(botMessageContainer, chatOutput.firstChild);
}
botMessage.textContent += message;
if (replace) {
botMessage.textContent = message;
}
else {
botMessage.textContent += message;
}
botMessage.style.color = 'black';
const isAtBottom = chatOutput.scrollHeight - chatOutput.clientHeight <= chatOutput.scrollTop + 1;
if (isAtBottom) {
@@ -62,23 +66,20 @@ ws.onmessage = function (event) {
response_dict = JSON.parse(event.data);
console.log(response_dict);
currentRequestID = JSON.parse(event.data)["task_id"];

switch (response_dict["status"]) {
case "starting":
currentMessage = '';
disableGenerateButton();
return;
case "finished":
currentMessage = '';
enableGenerateButton();
return;
case "waiting":
queue_pos = response_dict["queue_pos"];
botMessage.textContent = "Waiting in queue. Position: " + queue_pos;
botMessage.style.color = 'blue';
updateBotMessage("<br>" + "<span style='color:blue;'>You are in position " + queue_pos + " in the queue</span>", replace = true);
enableAbortButton();
return;
case "abort":
case "aborted":
updateBotMessage("<br>" + "<span style='color:red;'>Generation aborted</span>");
enableAbortButton();
return;
2 changes: 1 addition & 1 deletion src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py
@@ -4,7 +4,7 @@

from typing import Protocol, runtime_checkable

from .engine import ConcurrentEngine, Engine
from .hf_engine import HFEngine, HFSamplingParams
from .llama_cpp_engine import LlamaCppEngine, LlamaCppSamplingParams
from .utils import EngineProtocol
from .vllm_engine import VLLMEngine, VLLMSamplingParams
61 changes: 49 additions & 12 deletions src/AGISwarm/llm_instruct_ms/llm_engines/engine.py
@@ -1,7 +1,7 @@
"""Utility functions for LLM engines"""

from abc import abstractmethod
from typing import Dict, Generic, List, Protocol, TypeVar, cast, runtime_checkable
from typing import Dict, Generic, List, TypeVar, cast

from pydantic import BaseModel

@@ -19,12 +19,13 @@ class SamplingParams(BaseModel):
)


@runtime_checkable
# pylint: disable=too-few-public-methods
class EngineProtocol(Protocol, Generic[_SamplingParams_contra]):
class Engine(Generic[_SamplingParams_contra]):
"""Engine protocol"""

@abstractmethod
conversations: Dict[str, List[Dict[str, str]]]

# pylint: disable=too-many-arguments
async def __call__(
self,
conversation_id: str,
@@ -33,8 +34,26 @@ async def __call__(
reply_prefix: str,
sampling_params: _SamplingParams_contra,
):
"""Generate text from prompt"""
yield str()
if system_prompt != "":
self.conversations[conversation_id].append(
{
"role": "system",
"content": system_prompt,
}
)
self.conversations[conversation_id].append({"role": "user", "content": prompt})
reply: str = ""
async for response in self.generate(
self.conversations[conversation_id],
reply_prefix,
sampling_params,
):
reply += response
yield response
self.conversations[conversation_id].append(
{"role": "assistant", "content": reply}
)
yield ""

@abstractmethod
async def generate(
@@ -47,12 +66,12 @@ async def generate(
yield str()


@runtime_checkable
# pylint: disable=too-few-public-methods
class ConcurrentEngineProtocol(Protocol, Generic[_SamplingParams_contra]):
class ConcurrentEngine(Generic[_SamplingParams_contra]):
"""Concurrent engine protocol"""

@abstractmethod
conversations: Dict[str, List[Dict[str, str]]]

# pylint: disable=too-many-arguments
async def __call__(
self,
@@ -63,16 +82,34 @@ async def __call__(
sampling_params: _SamplingParams_contra,
task_id: str,
):
"""Generate text from prompt"""
yield str()
if conversation_id not in self.conversations:
self.conversations[conversation_id] = []
if system_prompt != "":
self.conversations[conversation_id].append(
{
"role": "system",
"content": system_prompt,
}
)
self.conversations[conversation_id].append({"role": "user", "content": prompt})
reply: str = ""
async for response in self.generate(
self.conversations[conversation_id], reply_prefix, sampling_params, task_id
):
reply += response
yield response
self.conversations[conversation_id].append(
{"role": "assistant", "content": reply}
)
yield ""

@abstractmethod
async def generate(
self,
task_id: str,
messages: List[Dict[str, str]],
reply_prefix: str,
sampling_params: _SamplingParams_contra,
task_id: str,
):
"""Generate text from prompt"""
yield str()
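Net effect of this file: conversation bookkeeping and streaming now live once in the Engine / ConcurrentEngine __call__ methods, and concrete engines only implement generate. A minimal, hypothetical subclass to show the contract; it assumes the module is importable as AGISwarm.llm_instruct_ms.llm_engines.engine and that SamplingParams can be constructed with defaults:

import asyncio
from typing import Dict, List

# The module path and the default-constructible SamplingParams are assumptions for this sketch.
from AGISwarm.llm_instruct_ms.llm_engines.engine import Engine, SamplingParams


class EchoEngine(Engine[SamplingParams]):
    """Toy engine: only generate() is overridden; the inherited __call__
    records the user/assistant turns in self.conversations and streams the reply."""

    def __init__(self) -> None:
        # Engine.__call__ indexes self.conversations[conversation_id] directly,
        # so the subclass must make sure the entry exists beforehand.
        self.conversations: Dict[str, List[Dict[str, str]]] = {"demo": []}

    async def generate(self, messages, reply_prefix, sampling_params):
        if reply_prefix:
            yield reply_prefix
        # Stream the last user message back word by word.
        for word in messages[-1]["content"].split():
            yield word + " "


async def main() -> None:
    engine = EchoEngine()
    # __call__(conversation_id, prompt, system_prompt, reply_prefix, sampling_params)
    async for chunk in engine("demo", "engines refactoring demo", "", "", SamplingParams()):
        print(chunk, end="")
    print()


asyncio.run(main())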
34 changes: 2 additions & 32 deletions src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py
@@ -7,7 +7,7 @@
import transformers # type: ignore
from pydantic import Field

from .utils import EngineProtocol, SamplingParams, prepare_prompt
from .engine import Engine, SamplingParams, prepare_prompt

SUPPORTED_MODELS = [
"meta-llama/Meta-Llama-3-8B-Instruct",
@@ -48,7 +48,7 @@ class HFSamplingParams(SamplingParams):


# pylint: disable=too-few-public-methods
class HFEngine(EngineProtocol[HFSamplingParams]): # pylint: disable=invalid-name
class HFEngine(Engine[HFSamplingParams]): # pylint: disable=invalid-name
"""LLM Instruct Model Inference"""

def __init__(
@@ -100,33 +100,3 @@ async def generate(
yield reply_prefix
for new_text in streamer:
yield cast(str, new_text)

# pylint: disable=too-many-arguments
async def __call__(
self,
conversation_id: str,
prompt: str,
system_prompt: str,
reply_prefix: str,
sampling_params: HFSamplingParams,
):
if system_prompt != "":
self.conversations[conversation_id].append(
{
"role": "system",
"content": system_prompt,
}
)
self.conversations[conversation_id].append({"role": "user", "content": prompt})
reply: str = ""
async for response in self.generate(
self.conversations[conversation_id],
reply_prefix,
sampling_params,
):
reply += response
yield response
self.conversations[conversation_id].append(
{"role": "assistant", "content": reply}
)
yield ""
34 changes: 2 additions & 32 deletions src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py
@@ -6,7 +6,7 @@
from pydantic import Field
from transformers import AutoTokenizer # type: ignore

from .utils import EngineProtocol, SamplingParams, prepare_prompt
from .engine import Engine, SamplingParams, prepare_prompt


class LlamaCppSamplingParams(SamplingParams):
@@ -17,7 +17,7 @@ class LlamaCppSamplingParams(SamplingParams):
presence_penalty: float = Field(default=0.0, description="Presence penalty")


class LlamaCppEngine(EngineProtocol[LlamaCppSamplingParams]):
class LlamaCppEngine(Engine[LlamaCppSamplingParams]):
"""LLM Instruct Model Inference"""

def __init__( # pylint: disable=too-many-arguments
@@ -63,33 +63,3 @@ async def generate(
):
output = cast(CreateCompletionStreamResponse, output)
yield output["choices"][0]["text"]

# pylint: disable=too-many-arguments
async def __call__(
self,
conversation_id: str,
prompt: str,
system_prompt: str,
reply_prefix: str,
sampling_params: LlamaCppSamplingParams,
):
if system_prompt != "":
self.conversations[conversation_id].append(
{
"role": "system",
"content": system_prompt,
}
)
self.conversations[conversation_id].append({"role": "user", "content": prompt})
reply: str = ""
async for response in self.generate(
self.conversations[conversation_id],
reply_prefix,
sampling_params,
):
reply += response
yield response
self.conversations[conversation_id].append(
{"role": "assistant", "content": reply}
)
yield ""
40 changes: 3 additions & 37 deletions src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py
@@ -8,7 +8,7 @@
from huggingface_hub import hf_hub_download
from pydantic import Field

from .utils import ConcurrentEngineProtocol, SamplingParams, prepare_prompt
from .engine import ConcurrentEngine, SamplingParams, prepare_prompt


class VLLMSamplingParams(SamplingParams):
@@ -19,7 +19,7 @@ class VLLMSamplingParams(SamplingParams):
presence_penalty: float = Field(default=0.0, description="Presence penalty")


class VLLMEngine(ConcurrentEngineProtocol[VLLMSamplingParams]):
class VLLMEngine(ConcurrentEngine[VLLMSamplingParams]):
"""LLM Instruct Model Inference using VLLM"""

def __init__(
@@ -64,10 +64,10 @@ def get_sampling_params(

async def generate(
self,
task_id: str,
messages: list[dict],
reply_prefix: str | None,
sampling_params: VLLMSamplingParams,
task_id: str,
):
"""Generate text from prompt"""
prompt = prepare_prompt(self.tokenizer, messages, reply_prefix)
@@ -82,37 +82,3 @@ async def generate(
current_len = len(output.outputs[0].text)
if output.finished:
break

# pylint: disable=too-many-arguments
async def __call__(
self,
conversation_id: str,
prompt: str,
system_prompt: str,
reply_prefix: str,
sampling_params: VLLMSamplingParams,
task_id: str,
):
if conversation_id not in self.conversations:
self.conversations[conversation_id] = []
if system_prompt != "":
self.conversations[conversation_id].append(
{
"role": "system",
"content": system_prompt,
}
)
self.conversations[conversation_id].append({"role": "user", "content": prompt})
reply: str = ""
async for response in self.generate(
task_id,
self.conversations[conversation_id],
reply_prefix,
sampling_params,
):
reply += response
yield response
self.conversations[conversation_id].append(
{"role": "assistant", "content": reply}
)
yield ""
1 change: 1 addition & 0 deletions src/AGISwarm/llm_instruct_ms/typing.py
@@ -38,6 +38,7 @@ class VLLMConfig(ModelConfig):

filename: str | None = None


class HFConfig(ModelConfig):
"""HF settings"""

