diff --git a/src/AGISwarm/llm_instruct_ms/app.py b/src/AGISwarm/llm_instruct_ms/app.py
index 1590c7f..a8a06d1 100644
--- a/src/AGISwarm/llm_instruct_ms/app.py
+++ b/src/AGISwarm/llm_instruct_ms/app.py
@@ -14,8 +14,7 @@
from omegaconf import OmegaConf
from pydantic import BaseModel
-from .llm_engines import EngineProtocol
-from .llm_engines.utils import ConcurrentEngineProtocol
+from .llm_engines import ConcurrentEngine, Engine
from .typing import (
ENGINE_MAP,
ENGINE_SAMPLING_PARAMS_MAP,
@@ -32,7 +31,7 @@ def __init__(self, config: LLMInstructConfig):
self.app = FastAPI()
if config.engine_config is None:
config.engine_config = cast(None, OmegaConf.create())
- self.llm_pipeline: EngineProtocol[Any] = ENGINE_MAP[config.engine]( # type: ignore
+ self.llm_pipeline: Engine[Any] = ENGINE_MAP[config.engine]( # type: ignore
hf_model_name=config.hf_model_name,
tokenizer_name=config.tokenizer_name,
**cast(dict, OmegaConf.to_container(config.engine_config)),
@@ -92,7 +91,7 @@ async def generate(self, websocket: WebSocket): # type: ignore
queued_task = self.queue_manager.queued_generator(
self.llm_pipeline.__call__,
pass_task_id=isinstance(
- self.llm_pipeline, ConcurrentEngineProtocol # type: ignore
+ self.llm_pipeline, ConcurrentEngine # type: ignore
),
)
# task_id and interrupt_event are created by the queued_generator
diff --git a/src/AGISwarm/llm_instruct_ms/gui/scripts.js b/src/AGISwarm/llm_instruct_ms/gui/scripts.js
index 2ae2f37..6fac3b0 100644
--- a/src/AGISwarm/llm_instruct_ms/gui/scripts.js
+++ b/src/AGISwarm/llm_instruct_ms/gui/scripts.js
@@ -13,7 +13,6 @@ function resetForm() {
document.getElementById('presence_penalty').value = DEFAULT_PRESENCE_PENALTY;
document.getElementById('system_prompt').value = DEFAULT_SYSTEM_PROMPT;
}
-let currentMessage = '';
let currentRequestID = '';
@@ -33,7 +32,7 @@ function enableAbortButton() {
document.getElementById('send-btn').disabled = false;
}
-function updateBotMessage(message) {
+function updateBotMessage(message, replace = false) {
const chatOutput = document.getElementById('chat-output');
// Check if the bot message div already exists
let botMessageContainer = chatOutput.firstElementChild;
@@ -47,7 +46,12 @@ function updateBotMessage(message) {
botMessageContainer.appendChild(botMessage);
chatOutput.insertBefore(botMessageContainer, chatOutput.firstChild);
}
- botMessage.textContent += message;
+ if (replace) {
+ botMessage.textContent = message;
+ }
+ else {
+ botMessage.textContent += message;
+ }
botMessage.style.color = 'black';
const isAtBottom = chatOutput.scrollHeight - chatOutput.clientHeight <= chatOutput.scrollTop + 1;
if (isAtBottom) {
@@ -62,23 +66,20 @@ ws.onmessage = function (event) {
response_dict = JSON.parse(event.data);
console.log(response_dict);
currentRequestID = JSON.parse(event.data)["task_id"];
-
+
switch (response_dict["status"]) {
case "starting":
- currentMessage = '';
disableGenerateButton();
return;
case "finished":
- currentMessage = '';
enableGenerateButton();
return;
case "waiting":
queue_pos = response_dict["queue_pos"];
- botMessage.textContent = "Waiting in queue. Position: " + queue_pos;
- botMessage.style.color = 'blue';
+ updateBotMessage("\n" + "You are in position " + queue_pos + " in the queue", /* replace */ true);
enableAbortButton();
return;
- case "abort":
+ case "aborted":
updateBotMessage("\n" + "Generation aborted");
enableAbortButton();
return;
diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py b/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py
index bab79d8..ca96b5d 100644
--- a/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py
+++ b/src/AGISwarm/llm_instruct_ms/llm_engines/__init__.py
@@ -4,7 +4,7 @@
from typing import Protocol, runtime_checkable
+from .engine import ConcurrentEngine, Engine
from .hf_engine import HFEngine, HFSamplingParams
from .llama_cpp_engine import LlamaCppEngine, LlamaCppSamplingParams
-from .utils import EngineProtocol
from .vllm_engine import VLLMEngine, VLLMSamplingParams
diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/utils.py b/src/AGISwarm/llm_instruct_ms/llm_engines/engine.py
similarity index 54%
rename from src/AGISwarm/llm_instruct_ms/llm_engines/utils.py
rename to src/AGISwarm/llm_instruct_ms/llm_engines/engine.py
index e10efe9..e324c90 100644
--- a/src/AGISwarm/llm_instruct_ms/llm_engines/utils.py
+++ b/src/AGISwarm/llm_instruct_ms/llm_engines/engine.py
@@ -1,7 +1,7 @@
"""Utility functions for LLM engines"""
from abc import abstractmethod
-from typing import Dict, Generic, List, Protocol, TypeVar, cast, runtime_checkable
+from typing import Dict, Generic, List, TypeVar, cast
from pydantic import BaseModel
@@ -19,12 +19,13 @@ class SamplingParams(BaseModel):
)
-@runtime_checkable
# pylint: disable=too-few-public-methods
-class EngineProtocol(Protocol, Generic[_SamplingParams_contra]):
+class Engine(Generic[_SamplingParams_contra]):
"""Engine protocol"""
- @abstractmethod
+ conversations: Dict[str, List[Dict[str, str]]]
+
+ # pylint: disable=too-many-arguments
async def __call__(
self,
conversation_id: str,
@@ -33,8 +34,26 @@ async def __call__(
reply_prefix: str,
sampling_params: _SamplingParams_contra,
):
- """Generate text from prompt"""
- yield str()
+ if system_prompt != "":
+ self.conversations[conversation_id].append(
+ {
+ "role": "system",
+ "content": system_prompt,
+ }
+ )
+ self.conversations[conversation_id].append({"role": "user", "content": prompt})
+ reply: str = ""
+ async for response in self.generate(
+ self.conversations[conversation_id],
+ reply_prefix,
+ sampling_params,
+ ):
+ reply += response
+ yield response
+ self.conversations[conversation_id].append(
+ {"role": "assistant", "content": reply}
+ )
+ yield ""
@abstractmethod
async def generate(
@@ -47,12 +66,12 @@ async def generate(
yield str()
-@runtime_checkable
# pylint: disable=too-few-public-methods
-class ConcurrentEngineProtocol(Protocol, Generic[_SamplingParams_contra]):
+class ConcurrentEngine(Generic[_SamplingParams_contra]):
"""Concurrent engine protocol"""
- @abstractmethod
+ conversations: Dict[str, List[Dict[str, str]]]
+
# pylint: disable=too-many-arguments
async def __call__(
self,
@@ -63,16 +82,34 @@ async def __call__(
sampling_params: _SamplingParams_contra,
task_id: str,
):
- """Generate text from prompt"""
- yield str()
+ if conversation_id not in self.conversations:
+ self.conversations[conversation_id] = []
+ if system_prompt != "":
+ self.conversations[conversation_id].append(
+ {
+ "role": "system",
+ "content": system_prompt,
+ }
+ )
+ self.conversations[conversation_id].append({"role": "user", "content": prompt})
+ reply: str = ""
+ async for response in self.generate(
+ self.conversations[conversation_id], reply_prefix, sampling_params, task_id
+ ):
+ reply += response
+ yield response
+ self.conversations[conversation_id].append(
+ {"role": "assistant", "content": reply}
+ )
+ yield ""
@abstractmethod
async def generate(
self,
- task_id: str,
messages: List[Dict[str, str]],
reply_prefix: str,
sampling_params: _SamplingParams_contra,
+ task_id: str,
):
"""Generate text from prompt"""
yield str()
diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py b/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py
index 332c2c9..58cb08d 100644
--- a/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py
+++ b/src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py
@@ -7,7 +7,7 @@
import transformers # type: ignore
from pydantic import Field
-from .utils import EngineProtocol, SamplingParams, prepare_prompt
+from .engine import Engine, SamplingParams, prepare_prompt
SUPPORTED_MODELS = [
"meta-llama/Meta-Llama-3-8B-Instruct",
@@ -48,7 +48,7 @@ class HFSamplingParams(SamplingParams):
# pylint: disable=too-few-public-methods
-class HFEngine(EngineProtocol[HFSamplingParams]): # pylint: disable=invalid-name
+class HFEngine(Engine[HFSamplingParams]): # pylint: disable=invalid-name
"""LLM Instruct Model Inference"""
def __init__(
@@ -100,33 +100,3 @@ async def generate(
yield reply_prefix
for new_text in streamer:
yield cast(str, new_text)
-
- # pylint: disable=too-many-arguments
- async def __call__(
- self,
- conversation_id: str,
- prompt: str,
- system_prompt: str,
- reply_prefix: str,
- sampling_params: HFSamplingParams,
- ):
- if system_prompt != "":
- self.conversations[conversation_id].append(
- {
- "role": "system",
- "content": system_prompt,
- }
- )
- self.conversations[conversation_id].append({"role": "user", "content": prompt})
- reply: str = ""
- async for response in self.generate(
- self.conversations[conversation_id],
- reply_prefix,
- sampling_params,
- ):
- reply += response
- yield response
- self.conversations[conversation_id].append(
- {"role": "assistant", "content": reply}
- )
- yield ""
diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py b/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py
index c45849f..24ed4c8 100644
--- a/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py
+++ b/src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py
@@ -6,7 +6,7 @@
from pydantic import Field
from transformers import AutoTokenizer # type: ignore
-from .utils import EngineProtocol, SamplingParams, prepare_prompt
+from .engine import Engine, SamplingParams, prepare_prompt
class LlamaCppSamplingParams(SamplingParams):
@@ -17,7 +17,7 @@ class LlamaCppSamplingParams(SamplingParams):
presence_penalty: float = Field(default=0.0, description="Presence penalty")
-class LlamaCppEngine(EngineProtocol[LlamaCppSamplingParams]):
+class LlamaCppEngine(Engine[LlamaCppSamplingParams]):
"""LLM Instruct Model Inference"""
def __init__( # pylint: disable=too-many-arguments
@@ -63,33 +63,3 @@ async def generate(
):
output = cast(CreateCompletionStreamResponse, output)
yield output["choices"][0]["text"]
-
- # pylint: disable=too-many-arguments
- async def __call__(
- self,
- conversation_id: str,
- prompt: str,
- system_prompt: str,
- reply_prefix: str,
- sampling_params: LlamaCppSamplingParams,
- ):
- if system_prompt != "":
- self.conversations[conversation_id].append(
- {
- "role": "system",
- "content": system_prompt,
- }
- )
- self.conversations[conversation_id].append({"role": "user", "content": prompt})
- reply: str = ""
- async for response in self.generate(
- self.conversations[conversation_id],
- reply_prefix,
- sampling_params,
- ):
- reply += response
- yield response
- self.conversations[conversation_id].append(
- {"role": "assistant", "content": reply}
- )
- yield ""
diff --git a/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py b/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py
index 5592af7..5df0e1e 100644
--- a/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py
+++ b/src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py
@@ -8,7 +8,7 @@
from huggingface_hub import hf_hub_download
from pydantic import Field
-from .utils import ConcurrentEngineProtocol, SamplingParams, prepare_prompt
+from .engine import ConcurrentEngine, SamplingParams, prepare_prompt
class VLLMSamplingParams(SamplingParams):
@@ -19,7 +19,7 @@ class VLLMSamplingParams(SamplingParams):
presence_penalty: float = Field(default=0.0, description="Presence penalty")
-class VLLMEngine(ConcurrentEngineProtocol[VLLMSamplingParams]):
+class VLLMEngine(ConcurrentEngine[VLLMSamplingParams]):
"""LLM Instruct Model Inference using VLLM"""
def __init__(
@@ -64,10 +64,10 @@ def get_sampling_params(
async def generate(
self,
- task_id: str,
messages: list[dict],
reply_prefix: str | None,
sampling_params: VLLMSamplingParams,
+ task_id: str,
):
"""Generate text from prompt"""
prompt = prepare_prompt(self.tokenizer, messages, reply_prefix)
@@ -82,37 +82,3 @@ async def generate(
current_len = len(output.outputs[0].text)
if output.finished:
break
-
- # pylint: disable=too-many-arguments
- async def __call__(
- self,
- conversation_id: str,
- prompt: str,
- system_prompt: str,
- reply_prefix: str,
- sampling_params: VLLMSamplingParams,
- task_id: str,
- ):
- if conversation_id not in self.conversations:
- self.conversations[conversation_id] = []
- if system_prompt != "":
- self.conversations[conversation_id].append(
- {
- "role": "system",
- "content": system_prompt,
- }
- )
- self.conversations[conversation_id].append({"role": "user", "content": prompt})
- reply: str = ""
- async for response in self.generate(
- task_id,
- self.conversations[conversation_id],
- reply_prefix,
- sampling_params,
- ):
- reply += response
- yield response
- self.conversations[conversation_id].append(
- {"role": "assistant", "content": reply}
- )
- yield ""
diff --git a/src/AGISwarm/llm_instruct_ms/typing.py b/src/AGISwarm/llm_instruct_ms/typing.py
index 8627b6f..fca48ec 100644
--- a/src/AGISwarm/llm_instruct_ms/typing.py
+++ b/src/AGISwarm/llm_instruct_ms/typing.py
@@ -38,6 +38,7 @@ class VLLMConfig(ModelConfig):
filename: str | None = None
+
class HFConfig(ModelConfig):
"""HF settings"""