Merge pull request #3 from AGISwarm/feature/llava
Feature/llava
DenisDiachkov authored Oct 8, 2024
2 parents 90fd660 + 68933b6 commit 37a3233
Showing 15 changed files with 466 additions and 205 deletions.
12 changes: 0 additions & 12 deletions config/NousResearch-Hermes-3-Llama-3.1-8B-GGUF.yaml

This file was deleted.

17 changes: 17 additions & 0 deletions config/chuanli11-Llama-3.2-3B-Instruct-uncensored.yaml
@@ -0,0 +1,17 @@
%YAML 1.1
---

hf_model_name: !!str chuanli11/Llama-3.2-3B-Instruct-uncensored # "alpindale/Llama-3.2-11B-Vision-Instruct"
tokenizer_name: chuanli11/Llama-3.2-3B-Instruct-uncensored # alpindale/Llama-3.2-11B-Vision-Instruct

engine: !!str VLLMEngine
engine_config:
  dtype: !!str float16
  gpu_memory_utilization: !!float 1.0
  tensor_parallel_size: !!int 2
  max_model_len: 32768


defaults:
  - gui_config: default
  - uvicorn_config: default
17 changes: 17 additions & 0 deletions config/deepvk-llava-saiga-8b.yaml
@@ -0,0 +1,17 @@
%YAML 1.1
---

hf_model_name: !!str deepvk/llava-saiga-8b # "alpindale/Llama-3.2-11B-Vision-Instruct"
tokenizer_name: deepvk/llava-saiga-8b # alpindale/Llama-3.2-11B-Vision-Instruct

engine: !!str VLLMEngine
engine_config:
  dtype: !!str float16
  gpu_memory_utilization: !!float 1.0
  tensor_parallel_size: !!int 2
  max_model_len: 8192
  limit_mm_per_prompt: {"image": 1}

defaults:
  - gui_config: default
  - uvicorn_config: default
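A note on how these values are consumed: the VLLMEngine wrapper itself is not part of this diff, but dtype, gpu_memory_utilization, tensor_parallel_size, max_model_len, and limit_mm_per_prompt correspond to vLLM engine arguments. A minimal sketch, assuming the config is forwarded more or less directly to vllm.LLM (the prompt template and image path below are illustrative, not taken from this commit):

from PIL import Image
from vllm import LLM, SamplingParams

# Mirror the engine_config above; limit_mm_per_prompt caps images per prompt at one.
llm = LLM(
    model="deepvk/llava-saiga-8b",
    dtype="float16",
    gpu_memory_utilization=1.0,
    tensor_parallel_size=2,
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
)

# Multimodal request: the exact prompt template is model-specific.
image = Image.open("example.jpg").convert("RGB")
outputs = llm.generate(
    {
        "prompt": "<image>\nDescribe the picture.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)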
15 changes: 15 additions & 0 deletions config/meta-llama-Meta-Llama-3.2.yaml
@@ -0,0 +1,15 @@
%YAML 1.1
---

hf_model_name: !!str "meta-llama/Llama-3.2-1B-Instruct"
tokenizer_name: !!str "meta-llama/Llama-3.2-1B-Instruct"

engine: !!str VLLMEngine
engine_config:
  dtype: !!str float16
  max_model_len: 8192
  gpu_memory_utilization: !!float 0.6

defaults:
  - gui_config: default
  - uvicorn_config: default
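These files are Hydra configs: the defaults list at the bottom composes the model-specific settings with shared gui_config and uvicorn_config groups. A minimal sketch of how such a config would typically be loaded; the entry-point module, config_path, and group layout are assumptions, not part of this diff:

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(
    config_path="config",
    config_name="meta-llama-Meta-Llama-3.2",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    # Print the fully composed config: model settings plus the merged
    # gui_config/uvicorn_config groups referenced in the defaults list.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()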
14 changes: 7 additions & 7 deletions pyproject.toml
@@ -14,14 +14,14 @@ keywords = ["sample", "setuptools", "development"]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
'numpy~=1.26.0',
"fastapi~=0.111.0",
"fastapi~=0.114.0",
"uvicorn~=0.29.0",
"pydantic~=2.9.0",
"hydra-core~=1.3.2",
"AGISwarm.asyncio_queue_manager",
]
[project.optional-dependencies]
vllm = ["vllm==0.5.5"]
vllm = ["vllm==0.6.2"]
transformers = ["transformers"]
llama-cpp = ["llama_cpp_python"]
backends = [
@@ -34,11 +34,11 @@ GUI = ["jinja2"]

test = ['pytest~=8.2.1']
analyze = [
'pyright~=1.1.364',
'pylint~=3.2.2',
'bandit~=1.7.8',
'black~=24.4.2',
'isort~=5.13.2',
'pyright',
'pylint',
'bandit',
'black',
'isort',
]
build = ['setuptools', 'wheel', 'build']
publish = ['twine']
124 changes: 56 additions & 68 deletions src/AGISwarm/llm_instruct_ms/app.py
@@ -1,8 +1,11 @@
"""Main module for the LLM instruct microservice"""

import asyncio
import base64
import logging
import re
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, cast

@@ -12,6 +15,7 @@
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader
from omegaconf import OmegaConf
from PIL import Image
from pydantic import BaseModel

from .llm_engines import ConcurrentEngine, Engine
@@ -38,8 +42,8 @@ def __init__(self, config: LLMInstructConfig):
)
self.sampling_settings_cls = ENGINE_SAMPLING_PARAMS_MAP[config.engine]
self.queue_manager = AsyncIOQueueManager(
max_concurrent_tasks=5,
sleep_time=0,
max_concurrent_tasks=2,
sleep_time=0.001,
)
self.start_abort_lock = asyncio.Lock()
self.setup_routes()
@@ -79,6 +83,28 @@ async def gui(self):
)
return FileResponse(Path(__file__).parent / "gui" / "current_index.html")

@staticmethod
def remove_mime_header(image_data):
"""
Remove the MIME type header from the image data and return the raw base64 data.
:return: the raw base64 data
"""
# Regular expression to match the MIME type header
mime_pattern = r"^data:image/([a-zA-Z]+);base64,"
match = re.match(mime_pattern, image_data)

if match:
# Remove the header to get the raw base64 data
base64_data = re.sub(mime_pattern, "", image_data)
return base64_data
# If there's no match, assume it's already raw base64 data
return image_data

def base64_to_image(self, image: str) -> Image.Image:
"""Convert base64 image to PIL image"""
image = self.remove_mime_header(image)
return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")

async def generate(self, websocket: WebSocket): # type: ignore
"""WebSocket endpoint"""
await websocket.accept()
@@ -87,78 +113,41 @@ async def generate(self, websocket: WebSocket): # type: ignore
while True:
data: Dict[str, Any] = await websocket.receive_json()
gen_config = SamplingConfig(data)
sampling_dict = self.sampling_settings_cls.model_validate(
gen_config,
strict=False,
)
image: Image.Image | None = (
self.base64_to_image(gen_config.image) if gen_config.image else None
)
# Enqueue the task (without starting it)
queued_task = self.queue_manager.queued_generator(
queued_task = self.queue_manager.queued_task(
self.llm_pipeline.__call__,
pass_task_id=isinstance(
self.llm_pipeline, ConcurrentEngine # type: ignore
),
warnings=(
["Image input not supported by this model"]
if image and not self.llm_pipeline.image_prompt_enabled
else None
),
raise_on_error=False,
print_error_tracebacks=True,
)
# task_id and interrupt_event are created by the queued_generator
task_id = queued_task.task_id
await websocket.send_json(
{
"status": TaskStatus.STARTING,
"task_id": task_id,
}
)
# Start the generation task
sampling_dict = self.sampling_settings_cls.model_validate(
gen_config,
strict=False,
)
try:
async for step_info in queued_task(
conversation_id,
gen_config.prompt,
gen_config.system_prompt,
gen_config.reply_prefix,
sampling_dict,
):
await asyncio.sleep(0)
if "status" not in step_info: # Task's return value.
await websocket.send_json(
{
"task_id": task_id,
"status": TaskStatus.RUNNING,
"tokens": step_info,
}
)
continue
if (
step_info["status"] == TaskStatus.WAITING
): # Queuing info returned
await websocket.send_json(step_info)
continue
if (
step_info["status"] != TaskStatus.RUNNING
): # Queuing info returned
await websocket.send_json(step_info)
break
await websocket.send_json(
{
"task_id": task_id,
"status": TaskStatus.FINISHED,
}
)
except asyncio.CancelledError as e:
logging.info(e)
await websocket.send_json(
{
"status": TaskStatus.ABORTED,
"task_id": task_id,
}
)
except Exception as e: # pylint: disable=broad-except
logging.error(e)
await websocket.send_json(
{
"status": TaskStatus.ERROR,
"message": str(e), ### loggging
}
)
async for step_info in queued_task(
conversation_id,
gen_config.prompt,
gen_config.system_prompt,
gen_config.reply_prefix,
image,
sampling_dict,
):
if step_info["status"] == TaskStatus.ERROR:
step_info["content"] = None
await websocket.send_json(step_info)
except WebSocketDisconnect:
print("Client disconnected", flush=True)
logging.info("Client %s disconnected", conversation_id)
finally:
self.llm_pipeline.conversations.pop(conversation_id, None)
await websocket.close()
@@ -170,7 +159,6 @@ class AbortRequest(BaseModel):

async def abort(self, request: AbortRequest):
"""Abort generation"""
print(f"ENTER ABORT Aborting request {request.task_id}")
async with self.start_abort_lock:
print(f"Aborting request {request.task_id}")
logging.info("Aborting task %s", request.task_id)
await self.queue_manager.abort_task(request.task_id)
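For reference, the image-handling helpers added in this file are small enough to exercise standalone. A self-contained sketch of the same data-URL round trip; it mirrors remove_mime_header and base64_to_image but is an illustration, not the service code:

import base64
import re
from io import BytesIO

from PIL import Image


def remove_mime_header(image_data: str) -> str:
    # Strip an optional "data:image/<type>;base64," prefix; plain base64 passes through.
    return re.sub(r"^data:image/([a-zA-Z]+);base64,", "", image_data)


def base64_to_image(image: str) -> Image.Image:
    # Decode to bytes, open with PIL, and normalize to RGB.
    return Image.open(BytesIO(base64.b64decode(remove_mime_header(image)))).convert("RGB")


# Round trip: encode a tiny red square as a PNG data URL, then decode it back.
buffer = BytesIO()
Image.new("RGB", (4, 4), (255, 0, 0)).save(buffer, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()
print(base64_to_image(data_url).size)  # (4, 4)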
17 changes: 13 additions & 4 deletions src/AGISwarm/llm_instruct_ms/gui/current_index.html
@@ -20,8 +20,18 @@ <h1 id="title"></h1>
</div>
<div id="chat-output" class="chat-output"></div>
<div class="input-container">
<div id="image-preview"></div>
<textarea id="prompt" class="form-control" placeholder="Enter your message" rows="1"></textarea>
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">Send</button>
<div class="input-group-append">
<button class="btn btn-outline-secondary" type="button" id="attach-image">
<i class="fas fa-paperclip"></i>
</button>
</div>
<div class="input-group-append">
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">
<i class="fas fa-paper-plane"></i>
</button>
</div>
</div>
<div class="input-container">
<textarea id="reply_prefix" class="form-control" placeholder="Enter reply prefix" value=""></textarea>
@@ -31,7 +41,7 @@ <h1 id="title"></h1>
value=></textarea>
</div>
</div>

<div class="config-container">

<div class="form-group">
@@ -47,8 +57,7 @@ <h1 id="title"></h1>
</div>
<div class="form-group">
<span for="top_p">Top P:</span>
<input type="number" id="top_p" class="form-control" min="1e-9" max="1.0" step="0.1"
value="0.95">
<input type="number" id="top_p" class="form-control" min="1e-9" max="1.0" step="0.1" value="0.95">

</div>
<div class="form-group">
17 changes: 13 additions & 4 deletions src/AGISwarm/llm_instruct_ms/gui/jinja2.html
@@ -20,8 +20,18 @@ <h1 id="title">{{gui_title}}</h1>
</div>
<div id="chat-output" class="chat-output"></div>
<div class="input-container">
<div id="image-preview"></div>
<textarea id="prompt" class="form-control" placeholder="Enter your message" rows="1"></textarea>
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">Send</button>
<div class="input-group-append">
<button class="btn btn-outline-secondary" type="button" id="attach-image">
<i class="fas fa-paperclip"></i>
</button>
</div>
<div class="input-group-append">
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">
<i class="fas fa-paper-plane"></i>
</button>
</div>
</div>
<div class="input-container">
<textarea id="reply_prefix" class="form-control" placeholder="Enter reply prefix" value=""></textarea>
@@ -31,7 +41,7 @@ <h1 id="title">{{gui_title}}</h1>
value={{system_prompt}}></textarea>
</div>
</div>

<div class="config-container">

<div class="form-group">
@@ -47,8 +57,7 @@ <h1 id="title">{{gui_title}}</h1>
</div>
<div class="form-group">
<span for="top_p">Top P:</span>
<input type="number" id="top_p" class="form-control" min="1e-9" max="1.0" step="0.1"
value="{{top_p}}">
<input type="number" id="top_p" class="form-control" min="1e-9" max="1.0" step="0.1" value="{{top_p}}">

</div>
<div class="form-group">
