Skip to content

Commit

Permalink
Chat interface, proper abort handling
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisDiachkov committed Sep 7, 2024
1 parent 1df2fed commit ca28ddf
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 181 deletions.
6 changes: 1 addition & 5 deletions conf/diffusion_config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,4 @@ dtype: "float32"
device: "cuda:0"
safety_checker:
requires_safety_checker: True
low_cpu_mem_usage: False

hydra:
job:
chdir: false
low_cpu_mem_usage: True
101 changes: 48 additions & 53 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@
import asyncio
import base64
import logging
import multiprocessing as mp
import traceback
from functools import partial
from io import BytesIO
from pathlib import Path

import hydra
import nest_asyncio
import numpy as np
import torch
import uvicorn
from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, RequestStatus
from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, TaskStatus
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
Expand All @@ -26,30 +24,7 @@

from .diffusion_pipeline import Text2ImagePipeline
from .typing import Config, DiffusionConfig, GUIConfig, Text2ImageGenerationConfig


def _to_task(future: asyncio.Future, as_task: bool, loop: asyncio.AbstractEventLoop):
if not as_task or isinstance(future, asyncio.Task):
return future
return loop.create_task(future)


def asyncio_run(future, as_task=True):
"""
A better implementation of `asyncio.run`.
:param future: A future or task or call of an async method.
:param as_task: Forces the future to be scheduled as task (needed for e.g. aiohttp).
"""

try:
loop = asyncio.get_running_loop()
except RuntimeError: # no event loop running:
loop = asyncio.new_event_loop()
return loop.run_until_complete(_to_task(future, as_task, loop))
else:
nest_asyncio.apply(loop)
return asyncio.run(_to_task(future, as_task, loop))
from .utils import asyncio_run


class Text2ImageApp:
Expand All @@ -63,6 +38,7 @@ def __init__(self, config: DiffusionConfig, gui_config: GUIConfig):
self.queue_manager = AsyncIOQueueManager()
self.text2image_pipeline = Text2ImagePipeline(config)
self.latent_update_frequency = gui_config.latent_update_frequency
self.start_abort_lock = asyncio.Lock()

def setup_routes(self):
"""
Expand All @@ -77,7 +53,7 @@ def setup_routes(self):

self.ws_router = APIRouter()
self.ws_router.add_websocket_route("/ws", self.generate)
self.ws_router.post("/abort")(self.abort)
self.app.post("/abort")(self.abort)
self.app.include_router(self.ws_router)

@staticmethod
Expand All @@ -101,20 +77,22 @@ def send_image(websocket: WebSocket, image: Image.Image, **kwargs):
)

@staticmethod
# pylint: disable=too-many-arguments
def diffusion_pipeline_step_callback(
websocket: WebSocket,
request_id: str,
task_id: str,
abort_event: asyncio.Event,
total_steps: int,
latent_update_frequency: int,
pipeline,
step: int,
timestep: int,
_: int,
callback_kwargs: dict,
):
"""Callback для StableDiffusionPipeline"""
if abort_event.is_set():
raise asyncio.CancelledError("Diffusion pipeline aborted")
asyncio_run(asyncio.sleep(0.01))
if step == 0 or step != total_steps and step % latent_update_frequency != 0:
return {"latents": callback_kwargs["latents"]}
with torch.no_grad():
Expand All @@ -123,8 +101,8 @@ def diffusion_pipeline_step_callback(
Text2ImageApp.send_image(
websocket,
image,
request_id=request_id,
status=RequestStatus.RUNNING,
task_id=task_id,
status=TaskStatus.RUNNING,
step=step,
total_steps=total_steps,
)
Expand All @@ -141,22 +119,28 @@ async def generate(self, websocket: WebSocket):
while True:
await asyncio.sleep(0.01)
data = await websocket.receive_text()
# Read generation config
gen_config = Text2ImageGenerationConfig.model_validate_json(data)
# Enqueue the task (without starting it)
queued_task = self.queue_manager.queued_task(
self.text2image_pipeline.generate
)

# request_id and interrupt_event are created by the queued_generator
request_id = queued_task.request_id
abort_event = self.queue_manager.abort_map[request_id]
async with self.start_abort_lock:
# Read generation config
gen_config = Text2ImageGenerationConfig.model_validate_json(data)
# Enqueue the task (without starting it)
queued_task = self.queue_manager.queued_task(
self.text2image_pipeline.generate
)
# task_id and interrupt_event are created by the queued_generator
task_id = queued_task.task_id
abort_event = self.queue_manager.abort_map[task_id]
await websocket.send_json(
{
"status": TaskStatus.STARTING,
"task_id": task_id,
}
)

# Diffusion step callback
callback_on_step_end = partial(
self.diffusion_pipeline_step_callback,
websocket,
request_id,
task_id,
abort_event,
gen_config.num_inference_steps,
self.latent_update_frequency,
Expand All @@ -171,34 +155,43 @@ async def generate(self, websocket: WebSocket):
Text2ImageApp.send_image(
websocket,
step_info["image"],
request_id=request_id,
status=RequestStatus.FINISHED,
task_id=task_id,
status=TaskStatus.FINISHED,
)
break
if (
step_info["status"] == RequestStatus.WAITING
step_info["status"] == TaskStatus.WAITING
): # Queuing info returned
await websocket.send_json(step_info)
continue
if (
step_info["status"] != RequestStatus.RUNNING
step_info["status"] != TaskStatus.RUNNING
): # Queuing info returned
await websocket.send_json(step_info)
break
except asyncio.CancelledError as e:
logging.info(e)
await websocket.send_json(
{
"status": RequestStatus.ABORTED,
"request_id": request_id,
"status": TaskStatus.ABORTED,
"task_id": task_id,
}
)
except Exception as e: # pylint: disable=broad-except
logging.error(e)
traceback.print_exc()
await websocket.send_json(
{
"status": TaskStatus.ERROR,
"message": str(e), ### loggging
}
)
except Exception as e: # pylint: disable=broad-except
logging.error(e)
traceback.print_exc()
await websocket.send_json(
{
"status": RequestStatus.ERROR,
"status": TaskStatus.ERROR,
"message": str(e), ### loggging
}
)
Expand All @@ -208,12 +201,14 @@ async def generate(self, websocket: WebSocket):
class AbortRequest(BaseModel):
"""Abort request"""

request_id: str
task_id: str

async def abort(self, request: AbortRequest):
"""Abort generation"""
print(f"Aborting request {request.request_id}")
await self.queue_manager.abort_task(request.request_id)
print(f"ENTER ABORT Aborting request {request.task_id}")
async with self.start_abort_lock:
print(f"Aborting request {request.task_id}")
await self.queue_manager.abort_task(request.task_id)

async def gui(self):
"""
Expand Down
20 changes: 12 additions & 8 deletions src/AGISwarm/text2image_ms/app/gui.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,29 @@
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Stable Diffusion Streaming Interface</title>
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC" crossorigin="anonymous">
<!--./style.css -->
<link rel="stylesheet" href="/static/styles.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">

</head>
<body>
<div class="container">
<div class="generation-container">
<div class="generation-header">
<div class="chat-container">
<div class="chat-header">
<h1 id="title">Stable Diffusion Streaming</h1>
<button id="menu-toggle" class="btn btn-secondary menu-toggle"><i class="fas fa-bars"></i></button>
</div>
<div id="image-output" class="image-output"></div>
<div id="chat-output" class="chat-output"></div>
<div class="input-container">
<textarea id="prompt" class="form-control" placeholder="Enter your prompt" rows="2"></textarea>
<textarea id="prompt" class="form-control" placeholder="Enter your message" rows="1"></textarea>
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">Send</button>
</div>
<textarea id="negative_prompt" class="form-control" placeholder="Enter your negative prompt" rows="2"></textarea>
<div class="input-container">
<textarea id="negative_prompt" class="form-control" placeholder="Enter your negative prompt"></textarea>
</div>
</div>

<div class="config-container">
Expand Down
Loading

0 comments on commit ca28ddf

Please sign in to comment.