Skip to content

Commit

Permalink
AsyincIOQueueManager related refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisDiachkov committed Oct 8, 2024
1 parent 664fd15 commit b7144b9
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 104 deletions.
4 changes: 2 additions & 2 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
import uvicorn

from .app import Text2ImageApp
from .typing import Config
from .typing import Text2ImageConfig


@hydra.main(
version_base=None,
config_name="config",
config_path=str(Path(os.getcwd()) / "config"),
)
def main(config: Config):
def main(config: Text2ImageConfig):
"""
The main function for the Text2Image service.
"""
Expand Down
112 changes: 44 additions & 68 deletions src/AGISwarm/text2image_ms/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from pathlib import Path

from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, TaskStatus
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic.main import BaseModel
from pydantic import BaseModel

from .diffusion_pipeline import Text2ImagePipeline
from .typing import DiffusionConfig, GUIConfig, Text2ImageGenerationConfig
Expand All @@ -32,7 +32,8 @@ def __init__(
):
self.app = FastAPI()
self.setup_routes()
self.queue_manager = AsyncIOQueueManager(sleep_time=0.0001)
self.sleep_time = 0.001
self.queue_manager = AsyncIOQueueManager(sleep_time=self.sleep_time)
self.text2image_pipeline = Text2ImagePipeline(hf_model_name, config)
self.latent_update_frequency = gui_config.latent_update_frequency
self.start_abort_lock = asyncio.Lock()
Expand Down Expand Up @@ -66,18 +67,22 @@ def pil2dataurl(image: Image.Image):
return dataurl

@staticmethod
def send_image(websocket: WebSocket, image: Image.Image, **kwargs):
def send_image(websocket: WebSocket, task_id: str, image: Image.Image, **kwargs):
"""
Send an image to the client.
"""
dataurl = Text2ImageApp.pil2dataurl(image)
asyncio_run(
websocket.send_json(
{
"image": dataurl,
"shape": image.size,
"task_id": task_id,
"status": TaskStatus.RUNNING,
"content": {
"image": dataurl,
"shape": image.size,
}
| kwargs,
}
| kwargs
)
)

Expand All @@ -90,6 +95,7 @@ def diffusion_pipeline_step_callback(
total_steps: int,
latent_update_frequency: int,
pipeline: Text2ImagePipeline,
sleep_time: float,
_diffusers_pipeline,
step: int,
_timestamp: int,
Expand All @@ -98,16 +104,31 @@ def diffusion_pipeline_step_callback(
"""Callback для StableDiffusionPipeline"""
if abort_event.is_set():
raise asyncio.CancelledError("Diffusion pipeline aborted")
asyncio_run(asyncio.sleep(0.0001))
if step == 0 or step != total_steps and step % latent_update_frequency != 0:
asyncio_run(asyncio.sleep(sleep_time))
if (
step == 0
or step % latent_update_frequency != latent_update_frequency - 1
and step != total_steps - 1
):
asyncio_run(
websocket.send_json(
{
"task_id": task_id,
"status": TaskStatus.RUNNING,
"content": {
"step": step + 1,
"total_steps": total_steps,
},
}
)
)
return callback_kwargs
image = pipeline.decode_latents(callback_kwargs["latents"])
Text2ImageApp.send_image(
websocket,
task_id,
image,
task_id=task_id,
status=TaskStatus.RUNNING,
step=step,
step=step + 1,
total_steps=total_steps,
)
return callback_kwargs
Expand All @@ -121,7 +142,6 @@ async def generate(self, websocket: WebSocket):

try:
while True:
await asyncio.sleep(0.0001)
data = await websocket.receive_text()
async with self.start_abort_lock:
# Read generation config
Expand All @@ -133,12 +153,6 @@ async def generate(self, websocket: WebSocket):
# task_id and interrupt_event are created by the queued_generator
task_id = queued_task.task_id
abort_event = self.queue_manager.abort_map[task_id]
await websocket.send_json(
{
"status": TaskStatus.STARTING,
"task_id": task_id,
}
)

# Diffusion step callback
callback_on_step_end = partial(
Expand All @@ -149,55 +163,18 @@ async def generate(self, websocket: WebSocket):
gen_config.num_inference_steps,
self.latent_update_frequency,
self.text2image_pipeline,
self.sleep_time,
)

# Start the generation task
try:
async for step_info in queued_task(
gen_config, callback_on_step_end
):
if "status" not in step_info: # Task's return value.
Text2ImageApp.send_image(
websocket,
step_info["image"],
task_id=task_id,
status=TaskStatus.FINISHED,
)
break
if (
step_info["status"] == TaskStatus.WAITING
): # Queuing info returned
await websocket.send_json(step_info)
continue
if (
step_info["status"] != TaskStatus.RUNNING
): # Queuing info returned
await websocket.send_json(step_info)
break
except asyncio.CancelledError as e:
logging.info(e)
await websocket.send_json(
{
"status": TaskStatus.ABORTED,
"task_id": task_id,
}
)
except Exception as e: # pylint: disable=broad-except
logging.error(e)
await websocket.send_json(
{
"status": TaskStatus.ERROR,
"message": str(e), ### loggging
}
)
except Exception as e: # pylint: disable=broad-except
logging.error(e)
await websocket.send_json(
{
"status": TaskStatus.ERROR,
"message": str(e), ### loggging
}
)
async for step_info in queued_task(gen_config, callback_on_step_end):
if step_info["status"] == TaskStatus.ERROR:
step_info["content"] = None
if step_info["status"] == TaskStatus.RUNNING:
# all running steps are managed by the callback
continue
await websocket.send_json(step_info)
except WebSocketDisconnect:
logging.info("Client disconnected")
finally:
await websocket.close()

Expand All @@ -208,9 +185,8 @@ class AbortRequest(BaseModel):

async def abort(self, request: AbortRequest):
"""Abort generation"""
print(f"ENTER ABORT Aborting request {request.task_id}")
async with self.start_abort_lock:
print(f"Aborting request {request.task_id}")
logging.info("Aborting request %s", request.task_id)
await self.queue_manager.abort_task(request.task_id)

async def gui(self):
Expand Down
4 changes: 1 addition & 3 deletions src/AGISwarm/text2image_ms/diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def generate(
if gen_config.seed != -1:
generator = torch.Generator()
generator.manual_seed(gen_config.seed)
return {
"image": self.pipeline(
return self.pipeline(
prompt=gen_config.prompt,
negative_prompt=gen_config.negative_prompt,
num_inference_steps=gen_config.num_inference_steps,
Expand All @@ -86,7 +85,6 @@ def generate(
height=gen_config.height,
callback_on_step_end=callback_on_step_end,
)["images"][0]
}

def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
"""
Expand Down
6 changes: 4 additions & 2 deletions src/AGISwarm/text2image_ms/gui/gui.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Stable Diffusion Streaming Interface</title>
<title>Stable Diffusion Interface</title>
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC" crossorigin="anonymous">
<!--./style.css -->
Expand All @@ -20,7 +20,9 @@ <h1 id="title">Stable Diffusion Streaming</h1>
<div id="chat-output" class="chat-output"></div>
<div class="input-container">
<textarea id="prompt" class="form-control" placeholder="Enter your message" rows="1"></textarea>
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">Send</button>
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">
<i class="fa fa-paper-plane"></i>
</button>
</div>
<div class="input-container">
<textarea id="negative_prompt" class="form-control" placeholder="Enter your negative prompt"></textarea>
Expand Down
53 changes: 31 additions & 22 deletions src/AGISwarm/text2image_ms/gui/script.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
let ws = new WebSocket(WEBSOCKET_URL);
let abort_url = ABORT_URL;
let currentRequestID = '';
let idle = true;

ws.onopen = function () {
console.log("WebSocket connection established");
Expand All @@ -13,26 +14,30 @@ ws.onmessage = function (event) {
switch (data["status"]) {
case "starting":
updateStatus("Starting generation");
disableGenerateButton();
break;
case "waiting":
updateStatus("Waiting for " + data["queue_pos"] + " requests to finish");
queue_pos = data["content"]["queue_pos"];
updateStatus("<br>" + "<span style='color:blue;'>You are in position " + queue_pos + " in the queue</span>" + "<br>");
enableAbortButton();
break;
case "running":
updateStatus(data["step"] + " / " + data["total_steps"]);
updateImage(data["image"]);
updateStatus(data["content"]["step"] + " / " + data["content"]["total_steps"]);
if ("image" in data["content"]){
updateImage(data["content"]["image"]);
}
enableAbortButton();
break;
case "finished":
updateStatus("");
updateImage(data["image"]);
enableGenerateButton();
break;
case "aborted":
updateStatus("Generation aborted");
updateStatus("<br>" + "<span style='color:red;'>Generation aborted</span>" + "<br>");
enableGenerateButton();
break;
case "error":
console.log("Error in the server");
updateStatus("Error in the server");
updateStatus("<br>" + "<span style='color:red;'>Error in generation</span>" + "<br>");
enableGenerateButton();
break;
}
Expand Down Expand Up @@ -117,35 +122,30 @@ function abortGeneration() {
})
.then(response => response.text())
.catch(error => console.error('Error aborting generation:', error));
// Enable the send button
document.getElementById('send-btn').style.backgroundColor = "#363d46";
document.getElementById('send-btn').textContent = "Send";
document.getElementById('send-btn').disabled = false;
console.log("Generation aborted.");
}

function sendButtonClick() {
document.getElementById('send-btn').disabled = true;
if (document.getElementById('send-btn').textContent === "Send") {
if (idle) {
sendMessage();
}
else if (document.getElementById('send-btn').textContent === "Abort") {
else {
abortGeneration();
}
}

function enterSend(event) {
if (event.key === 'Enter' && !event.ctrlKey) {
event.preventDefault();
if (document.getElementById('send-btn').textContent === "Send") {
if (idle){
document.getElementById('send-btn').disabled = true;
sendMessage();
}
} else if (event.key === 'Enter' && event.ctrlKey) {
// Allow new line with Ctrl+Enter
this.value += '\n';
}
};
}

document.getElementById('prompt').addEventListener('keydown', enterSend);
document.getElementById('negative_prompt').addEventListener('keydown', enterSend);
Expand All @@ -154,11 +154,11 @@ function updateStatus(message) {
botMessage = get_bot_message_container();
statusElement = botMessage.querySelector('#status');
if (!statusElement) {
statusElement = document.createElement('div');
statusElement = document.createElement('pre');
statusElement.id = 'status';
botMessage.appendChild(statusElement);
}
statusElement.textContent = message;
statusElement.innerHTML = message;
};


Expand Down Expand Up @@ -199,17 +199,26 @@ function updateImage(base64Image) {
};

function disableGenerateButton() {
document.getElementById('send-btn').style.backgroundColor = "#808080";
document.getElementById('send-btn').textContent = "Abort";
document.getElementById('send-btn').disabled = false;
document.getElementById('send-btn').disabled = true;
};

function enableGenerateButton() {
document.getElementById('send-btn').style.backgroundColor = "#363d46";
document.getElementById('send-btn').textContent = "Send";
document.getElementById('send-btn').innerHTML = '<i class="fa fa-paper-plane"></i>';
document.getElementById('send-btn').disabled = false;
idle = true;
};

function enableAbortButton() {
if (!idle) {
return;
}
document.getElementById('send-btn').style.backgroundColor = "#363d46";
document.getElementById('send-btn').innerHTML = '<i class="fa fa-stop"></i>';
document.getElementById('send-btn').disabled = false;
idle = false;
}

function resetForm() {
document.getElementById('num_inference_steps').value = '50';
document.getElementById('guidance_scale').value = '7.5';
Expand Down
Loading

0 comments on commit b7144b9

Please sign in to comment.