separate diffusion logic from app logic
DenisDiachkov committed Sep 8, 2024
1 parent 71f5cdd commit 180004b
Showing 2 changed files with 43 additions and 18 deletions.
37 changes: 20 additions & 17 deletions src/AGISwarm/text2image_ms/__main__.py
@@ -6,14 +6,11 @@
 import asyncio
 import base64
 import logging
-import traceback
 from functools import partial
 from io import BytesIO
 from pathlib import Path
 
 import hydra
-import numpy as np
-import torch
 import uvicorn
 from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, TaskStatus
 from fastapi import APIRouter, FastAPI, WebSocket
@@ -35,7 +32,7 @@ class Text2ImageApp:
     def __init__(self, config: DiffusionConfig, gui_config: GUIConfig):
         self.app = FastAPI()
         self.setup_routes()
-        self.queue_manager = AsyncIOQueueManager()
+        self.queue_manager = AsyncIOQueueManager(sleep_time=0.0001)
         self.text2image_pipeline = Text2ImagePipeline(config)
         self.latent_update_frequency = gui_config.latent_update_frequency
         self.start_abort_lock = asyncio.Lock()
@@ -57,15 +54,23 @@ def setup_routes(self):
         self.app.include_router(self.ws_router)
 
     @staticmethod
-    def send_image(websocket: WebSocket, image: Image.Image, **kwargs):
+    def pil2dataurl(image: Image.Image):
         """
-        Send an image to the client.
+        Convert a PIL image to a data URL.
         """
         image_io = BytesIO()
         image.save(image_io, "PNG")
         dataurl = "data:image/png;base64," + base64.b64encode(
             image_io.getvalue()
         ).decode("ascii")
+        return dataurl
+
+    @staticmethod
+    def send_image(websocket: WebSocket, image: Image.Image, **kwargs):
+        """
+        Send an image to the client.
+        """
+        dataurl = Text2ImageApp.pil2dataurl(image)
         asyncio_run(
             websocket.send_json(
                 {
@@ -84,20 +89,19 @@ def diffusion_pipeline_step_callback(
         abort_event: asyncio.Event,
         total_steps: int,
         latent_update_frequency: int,
-        pipeline,
+        pipeline: Text2ImagePipeline,
+        _diffusers_pipeline,
         step: int,
-        _: int,
+        _timestamp: int,
         callback_kwargs: dict,
     ):
         """Callback for StableDiffusionPipeline"""
         if abort_event.is_set():
             raise asyncio.CancelledError("Diffusion pipeline aborted")
-        asyncio_run(asyncio.sleep(0.01))
+        asyncio_run(asyncio.sleep(0.0001))
         if step == 0 or step != total_steps and step % latent_update_frequency != 0:
-            return {"latents": callback_kwargs["latents"]}
-        with torch.no_grad():
-            image = pipeline.decode_latents(callback_kwargs["latents"].clone())[0]
-            image = Image.fromarray((image * 255).astype(np.uint8))
+            return callback_kwargs
+        image = pipeline.decode_latents(callback_kwargs["latents"])
         Text2ImageApp.send_image(
             websocket,
             image,
@@ -106,7 +110,7 @@ def diffusion_pipeline_step_callback(
             step=step,
             total_steps=total_steps,
         )
-        return {"latents": callback_kwargs["latents"]}
+        return callback_kwargs
 
     async def generate(self, websocket: WebSocket):
         """
@@ -117,7 +121,7 @@ async def generate(self, websocket: WebSocket):
 
         try:
             while True:
-                await asyncio.sleep(0.01)
+                await asyncio.sleep(0.0001)
                 data = await websocket.receive_text()
                 async with self.start_abort_lock:
                     # Read generation config
@@ -144,6 +148,7 @@ async def generate(self, websocket: WebSocket):
                         abort_event,
                         gen_config.num_inference_steps,
                         self.latent_update_frequency,
+                        self.text2image_pipeline,
                     )
 
                     # Start the generation task
@@ -179,7 +184,6 @@ async def generate(self, websocket: WebSocket):
                     )
                 except Exception as e:  # pylint: disable=broad-except
                     logging.error(e)
-                    traceback.print_exc()
                     await websocket.send_json(
                         {
                             "status": TaskStatus.ERROR,
@@ -188,7 +192,6 @@ async def generate(self, websocket: WebSocket):
             )
         except Exception as e:  # pylint: disable=broad-except
             logging.error(e)
-            traceback.print_exc()
             await websocket.send_json(
                 {
                     "status": TaskStatus.ERROR,
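The refactor above splits the PNG-to-data-URL conversion out of send_image into a reusable pil2dataurl helper. A standalone sketch of the same conversion, runnable outside the app (the tiny red test image is illustrative):

import base64
from io import BytesIO

from PIL import Image

# Build a tiny test image; any PIL image works the same way.
image = Image.new("RGB", (2, 2), color="red")

# Same steps as pil2dataurl: serialize to PNG in memory, then base64-encode.
buffer = BytesIO()
image.save(buffer, "PNG")
dataurl = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("ascii")

# Every PNG starts with the same magic bytes, so the prefix is predictable.
assert dataurl.startswith("data:image/png;base64,iVBOR")
print(dataurl[:60], "...")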
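The callback changes hinge on functools.partial: the app binds its own arguments up front (websocket, abort event, step budget, update frequency, and, new in this commit, the Text2ImagePipeline wrapper), and diffusers supplies the remaining four positional arguments on every step. A minimal, self-contained sketch of that binding pattern with dummy stand-ins, not the repository's actual wiring:

import asyncio
from functools import partial


def step_callback(
    abort_event,          # app-side: set to cancel generation
    total_steps,          # app-side: full step budget
    update_frequency,     # app-side: decode a preview every N steps
    pipeline,             # app-side: wrapper that knows how to decode latents
    _diffusers_pipeline,  # the rest is supplied by diffusers on each step
    step,
    _timestamp,
    callback_kwargs,
):
    if abort_event.is_set():
        raise asyncio.CancelledError("Diffusion pipeline aborted")
    # Same skip rule as the commit: decode only on the final step
    # or every update_frequency-th step, never on step 0.
    if step != 0 and (step == total_steps or step % update_frequency == 0):
        print(f"step {step}/{total_steps}: decode a preview and push it")
    return callback_kwargs


abort = asyncio.Event()
# Bind the app-side arguments once, as __main__.py does with partial(...).
bound = partial(step_callback, abort, 20, 5, object())
for step in range(1, 21):
    # This mimics how diffusers would invoke callback_on_step_end.
    bound(None, step, 0, {"latents": None})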
24 changes: 23 additions & 1 deletion src/AGISwarm/text2image_ms/diffusion_pipeline.py
@@ -71,15 +71,37 @@ def generate(
             callback_on_step_end (Optional[Callable[[dict], None]]):
                 The callback function to call on each step end.
         """
+        generator = None
+        if gen_config.seed != -1:
+            generator = torch.Generator()
+            generator.manual_seed(gen_config.seed)
         return {
             "image": self.pipeline(
                 prompt=gen_config.prompt,
                 negative_prompt=gen_config.negative_prompt,
                 num_inference_steps=gen_config.num_inference_steps,
                 guidance_scale=gen_config.guidance_scale,
-                generator=torch.Generator().manual_seed(gen_config.seed),
+                generator=generator,
                 width=gen_config.width,
                 height=gen_config.height,
                 callback_on_step_end=callback_on_step_end,
             )["images"][0]
         }
+
+    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+        """
+        Decode the latents to an image.
+
+        Args:
+            latents (torch.Tensor): The latents to decode to an image.
+            output_type (str, optional): The output type of the image. Defaults to "pil".
+                Options: "pil", "np", "torch".
+        """
+        image = self.pipeline.vae.decode(
+            latents / self.pipeline.vae.config.scaling_factor,
+            return_dict=False,
+        )[0]
+        image = self.pipeline.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=[True]
+        )[0]
+        return image
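The new decode_latents helper is what the refactored callback in __main__.py calls to turn intermediate latents into a preview image: undo the VAE scaling factor, decode to pixel space, then postprocess to PIL. A hedged usage sketch of the same two steps against diffusers directly; the model id, device, and latent shape are illustrative assumptions, not from this repository:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any StableDiffusionPipeline-compatible model works.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Dummy latents shaped for a 512x512 image: 4 channels at 1/8 resolution.
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")

with torch.no_grad():
    # Undo the scaling applied at encode time, then decode to pixel space.
    image = pipe.vae.decode(
        latents / pipe.vae.config.scaling_factor, return_dict=False
    )[0]

# De-normalize from [-1, 1] and convert to a PIL image, as decode_latents does.
pil_image = pipe.image_processor.postprocess(
    image, output_type="pil", do_denormalize=[True]
)[0]
pil_image.save("preview.png")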
