From 180004b7dd3bab8f5a06502ac2c5418c8fe5a3c6 Mon Sep 17 00:00:00 2001 From: Denis Diachkov Date: Sun, 8 Sep 2024 14:55:26 +0200 Subject: [PATCH] separate diffusion logic from app logic --- src/AGISwarm/text2image_ms/__main__.py | 37 ++++++++++--------- .../text2image_ms/diffusion_pipeline.py | 24 +++++++++++- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/src/AGISwarm/text2image_ms/__main__.py b/src/AGISwarm/text2image_ms/__main__.py index 33b8fca..e8605cd 100644 --- a/src/AGISwarm/text2image_ms/__main__.py +++ b/src/AGISwarm/text2image_ms/__main__.py @@ -6,14 +6,11 @@ import asyncio import base64 import logging -import traceback from functools import partial from io import BytesIO from pathlib import Path import hydra -import numpy as np -import torch import uvicorn from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, TaskStatus from fastapi import APIRouter, FastAPI, WebSocket @@ -35,7 +32,7 @@ class Text2ImageApp: def __init__(self, config: DiffusionConfig, gui_config: GUIConfig): self.app = FastAPI() self.setup_routes() - self.queue_manager = AsyncIOQueueManager() + self.queue_manager = AsyncIOQueueManager(sleep_time=0.0001) self.text2image_pipeline = Text2ImagePipeline(config) self.latent_update_frequency = gui_config.latent_update_frequency self.start_abort_lock = asyncio.Lock() @@ -57,15 +54,23 @@ def setup_routes(self): self.app.include_router(self.ws_router) @staticmethod - def send_image(websocket: WebSocket, image: Image.Image, **kwargs): + def pil2dataurl(image: Image.Image): """ - Send an image to the client. + Convert a PIL image to a data URL. """ image_io = BytesIO() image.save(image_io, "PNG") dataurl = "data:image/png;base64," + base64.b64encode( image_io.getvalue() ).decode("ascii") + return dataurl + + @staticmethod + def send_image(websocket: WebSocket, image: Image.Image, **kwargs): + """ + Send an image to the client. + """ + dataurl = Text2ImageApp.pil2dataurl(image) asyncio_run( websocket.send_json( { @@ -84,20 +89,19 @@ def diffusion_pipeline_step_callback( abort_event: asyncio.Event, total_steps: int, latent_update_frequency: int, - pipeline, + pipeline: Text2ImagePipeline, + _diffusers_pipeline, step: int, - _: int, + _timestamp: int, callback_kwargs: dict, ): """Callback для StableDiffusionPipeline""" if abort_event.is_set(): raise asyncio.CancelledError("Diffusion pipeline aborted") - asyncio_run(asyncio.sleep(0.01)) + asyncio_run(asyncio.sleep(0.0001)) if step == 0 or step != total_steps and step % latent_update_frequency != 0: - return {"latents": callback_kwargs["latents"]} - with torch.no_grad(): - image = pipeline.decode_latents(callback_kwargs["latents"].clone())[0] - image = Image.fromarray((image * 255).astype(np.uint8)) + return callback_kwargs + image = pipeline.decode_latents(callback_kwargs["latents"]) Text2ImageApp.send_image( websocket, image, @@ -106,7 +110,7 @@ def diffusion_pipeline_step_callback( step=step, total_steps=total_steps, ) - return {"latents": callback_kwargs["latents"]} + return callback_kwargs async def generate(self, websocket: WebSocket): """ @@ -117,7 +121,7 @@ async def generate(self, websocket: WebSocket): try: while True: - await asyncio.sleep(0.01) + await asyncio.sleep(0.0001) data = await websocket.receive_text() async with self.start_abort_lock: # Read generation config @@ -144,6 +148,7 @@ async def generate(self, websocket: WebSocket): abort_event, gen_config.num_inference_steps, self.latent_update_frequency, + self.text2image_pipeline, ) # Start the generation task @@ -179,7 +184,6 @@ async def generate(self, websocket: WebSocket): ) except Exception as e: # pylint: disable=broad-except logging.error(e) - traceback.print_exc() await websocket.send_json( { "status": TaskStatus.ERROR, @@ -188,7 +192,6 @@ async def generate(self, websocket: WebSocket): ) except Exception as e: # pylint: disable=broad-except logging.error(e) - traceback.print_exc() await websocket.send_json( { "status": TaskStatus.ERROR, diff --git a/src/AGISwarm/text2image_ms/diffusion_pipeline.py b/src/AGISwarm/text2image_ms/diffusion_pipeline.py index b357277..caa2cb5 100644 --- a/src/AGISwarm/text2image_ms/diffusion_pipeline.py +++ b/src/AGISwarm/text2image_ms/diffusion_pipeline.py @@ -71,15 +71,37 @@ def generate( callback_on_step_end (Optional[Callable[[dict], None]): The callback function to call on each step end. """ + generator = None + if gen_config.seed != -1: + generator = torch.Generator() + generator.manual_seed(gen_config.seed) return { "image": self.pipeline( prompt=gen_config.prompt, negative_prompt=gen_config.negative_prompt, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, - generator=torch.Generator().manual_seed(gen_config.seed), + generator=generator, width=gen_config.width, height=gen_config.height, callback_on_step_end=callback_on_step_end, )["images"][0] } + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + """ + Decode the latents to an image. + + Args: + latents (torch.Tensor): The latents to decode to an image. + output_type (str, optional): The output type of the image. Defaults to "pil". + options: "pil", "np", "torch". + """ + image = self.pipeline.vae.decode( + latents / self.pipeline.vae.config.scaling_factor, + return_dict=False, + )[0] + image = self.pipeline.image_processor.postprocess( + image, output_type=output_type, do_denormalize=[True] + )[0] + return image