From 2c7631288b666e8b7e6193aff8f0787057cc7db4 Mon Sep 17 00:00:00 2001 From: Denis Diachkov Date: Tue, 3 Sep 2024 19:27:16 +0200 Subject: [PATCH] all works --- .gitignore | 3 +- .isort.cfg | 2 + conf/config.yaml | 10 + pyproject.toml | 11 +- pyrightconfig.json | 7 + src/AGISwarm/text2image_ms/__main__.py | 116 +++++++++++- src/AGISwarm/text2image_ms/app/gui.html | 56 ++++++ src/AGISwarm/text2image_ms/app/script.js | 129 +++++++++++++ src/AGISwarm/text2image_ms/app/styles.css | 127 +++++++++++++ .../text2image_ms/diffusion_pipeline.py | 175 ++++++++++++++++++ src/AGISwarm/text2image_ms/typing.py | 36 ++++ src/AGISwarm/text2image_ms/utils.py | 61 ++++++ 12 files changed, 724 insertions(+), 9 deletions(-) create mode 100644 .isort.cfg create mode 100644 conf/config.yaml create mode 100644 pyrightconfig.json create mode 100644 src/AGISwarm/text2image_ms/app/gui.html create mode 100644 src/AGISwarm/text2image_ms/app/script.js create mode 100644 src/AGISwarm/text2image_ms/app/styles.css create mode 100644 src/AGISwarm/text2image_ms/diffusion_pipeline.py create mode 100644 src/AGISwarm/text2image_ms/typing.py create mode 100644 src/AGISwarm/text2image_ms/utils.py diff --git a/.gitignore b/.gitignore index 2d60ce0..5a172d0 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,5 @@ __pycache__/ .cache # due to using vscode -.vscode/ \ No newline at end of file +.vscode/ +outputs/ \ No newline at end of file diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..6860bdb --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +profile = black \ No newline at end of file diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000..3ec0df5 --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,10 @@ +model: "runwayml/stable-diffusion-v1-5" +dtype: "float32" +device: "cuda:0" +safety_checker: +requires_safety_checker: True +low_cpu_mem_usage: False + +hydra: + job: + chdir: false \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4f95d99..0175518 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,13 +8,20 @@ dynamic = ["version"] description = "Python template project" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.10" license = { file = "LICENSE.txt" } keywords = ["sample", "setuptools", "development"] classifiers = [ "Programming Language :: Python :: 3", ] -dependencies = ['numpy'] +dependencies = [ + 'numpy<2.0.0', + "flask", + "pillow", + "diffusers", + "torch", + "deepspeed" +] [project.optional-dependencies] test = ['pytest'] analyze = ['pyright', 'pylint', 'bandit', 'black', 'isort'] diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..0a65641 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,7 @@ +{ + "pythonVersion": "3.13", // Specify the Python version to use + "pythonPlatform": "Linux", // Specify the platform (Windows, Linux, or macOS) + "typeCheckingMode": "basic", // Set the type checking mode (basic or strict) + "reportMissingTypeStubs": false, // Report missing type stubs for imported modules + "reportPrivateImportUsage": false, // Report usage of private imports +} \ No newline at end of file diff --git a/src/AGISwarm/text2image_ms/__main__.py b/src/AGISwarm/text2image_ms/__main__.py index 826ad60..ea49858 100644 --- a/src/AGISwarm/text2image_ms/__main__.py +++ b/src/AGISwarm/text2image_ms/__main__.py @@ -1,12 +1,116 @@ -"""Main.py""" +""" +This module is the entry point for the text2image_ms service. -from . import __version__ +""" +import asyncio +import base64 +import logging +import traceback +import uuid +from io import BytesIO +from pathlib import Path -def main(): - """Main function""" - print(__version__) +import hydra +import uvicorn +from fastapi import Body, FastAPI, WebSocket +from fastapi.responses import FileResponse, HTMLResponse +from fastapi.staticfiles import StaticFiles +from hydra.core.config_store import ConfigStore +from jinja2 import Environment, FileSystemLoader +from PIL import Image + +from .diffusion_pipeline import Text2ImagePipeline +from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig + + +class Text2ImageApp: + """ + A class to represent the Text2Image service. + """ + + def __init__(self, config: Text2ImagePipelineConfig): + self.app = FastAPI(debug=True) + self.text2image_pipeline = Text2ImagePipeline(config) + self.setup_routes() + + def setup_routes(self): + """ + Set up the routes for the Text2Image service. + """ + self.app.get("/", response_class=HTMLResponse)(self.gui) + self.app.mount( + "/static", + StaticFiles(directory=Path(__file__).parent / "app", html=True), + name="static", + ) + self.app.add_websocket_route("/ws", self.generate) + + async def generate(self, websocket: WebSocket): + """ + Generate an image from a text prompt using the Text2Image pipeline. + """ + + await websocket.accept() + + try: + while True: + await asyncio.sleep(0.1) + data = await websocket.receive_text() + print (data) + gen_config = Text2ImageGenerationConfig.model_validate_json(data) + request_id = str(uuid.uuid4()) + async for step_info in self.text2image_pipeline.generate(request_id, gen_config): + if step_info['type'] == 'waiting': + await websocket.send_json(step_info) + continue + latents = step_info['image'] + + # Конвертируем латенты в base64 + image_io = BytesIO() + latents.save(image_io, 'PNG') + dataurl = 'data:image/png;base64,' + base64.b64encode(image_io.getvalue()).decode('ascii') + # Отправляем инфу о прогрессе и латенты + await websocket.send_json({ + "type": "generation_step", + "step": step_info['step'], + "total_steps": step_info['total_steps'], + "latents": dataurl, + "shape": latents.size + }) + + await websocket.send_json({ + "type": "generation_complete" + }) + except Exception as e: + logging.error(e) + traceback.print_exc() + await websocket.send_json( + { + "type": "error", + "message": str(e), + } + ) + finally: + await websocket.close() + + def gui(self): + """ + Get the GUI for the Text2Image service. + """ + print("GUI") + path = Path(__file__).parent / "app" / "gui.html" + return FileResponse(path) + + +@hydra.main(config_name="config") +def main(config: Text2ImagePipelineConfig): + """ + The main function for the Text2Image service. + """ + text2image_app = Text2ImageApp(config) + uvicorn.run(text2image_app.app, host="localhost", port=8002, log_level="debug") if __name__ == "__main__": - main() + main() # pylint: disable=no-value-for-parameter diff --git a/src/AGISwarm/text2image_ms/app/gui.html b/src/AGISwarm/text2image_ms/app/gui.html new file mode 100644 index 0000000..4d543ea --- /dev/null +++ b/src/AGISwarm/text2image_ms/app/gui.html @@ -0,0 +1,56 @@ + + + + + + Stable Diffusion Streaming Interface + + + + + +
+
+
+

Stable Diffusion Streaming

+ +
+
+
+ + +
+ +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ +
+
+ + + + \ No newline at end of file diff --git a/src/AGISwarm/text2image_ms/app/script.js b/src/AGISwarm/text2image_ms/app/script.js new file mode 100644 index 0000000..85e5c37 --- /dev/null +++ b/src/AGISwarm/text2image_ms/app/script.js @@ -0,0 +1,129 @@ +let ws = new WebSocket(WEBSOCKET_URL); +let currentRequestID = ''; + +ws.onopen = function () { + console.log("WebSocket connection established"); +}; + +ws.onmessage = function (event) { + const data = JSON.parse(event.data); + console.log('Message received:', data); + if (data["type"] == "generation_complete") { + enableGenerateButton(); + return; + }; + if (data["type"] == "waiting") { + msg = data["msg"]; + console.log(msg); + return; + } + + imgElement = document.getElementById('image-output'); + let img = imgElement.querySelector('img'); + if (!img) { + img = document.createElement('img'); + imgElement.appendChild(img); + } + img.src = data['latents']; + console.log('Image updated:', data); +}; + +ws.onclose = function (event) { + console.log("WebSocket connection closed"); +}; + +function decodeBase64Image(base64String) { + // Remove data URL prefix if present + const base64Data = base64String.replace(/^data:image\/\w+;base64,/, ''); + + // Decode base64 + const binaryString = atob(base64Data); + + // Create Uint8Array + const uint8Array = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + uint8Array[i] = binaryString.charCodeAt(i); + } + + // Create Blob + return new Blob([uint8Array], { type: 'image/png' }); +}; + + +function sendMessage() { + const prompt = document.getElementById('prompt').value; + const negative_prompt = document.getElementById('negative_prompt').value; + const num_inference_steps = document.getElementById('num_inference_steps').value; + const guidance_scale = document.getElementById('guidance_scale').value; + const width = document.getElementById('width').value; + const height = document.getElementById('height').value; + const seed = document.getElementById('seed').value; + + const message = { + prompt: prompt, + negative_prompt: negative_prompt, + num_inference_steps: parseInt(num_inference_steps), + guidance_scale: parseFloat(guidance_scale), + width: parseInt(width), + height: parseInt(height), + seed: parseInt(seed) + }; + + ws.send(JSON.stringify(message)); + console.log('Message sent:', message); + disableGenerateButton(); +}; + +function updateStatus(step, total_steps) { + const statusElement = document.getElementById('status'); + if (!statusElement) { + const statusElement = document.createElement('div'); + statusElement.id = 'status'; + document.getElementById('image-output').appendChild(statusElement); + } + statusElement.textContent = `Generating: ${step}/${total_steps}`; +} + + +function updateImage(base64Image) { + const img = document.getElementById('generated-image'); + if (!img) { + const img = document.createElement('img'); + img.id = 'generated-image'; + document.body.appendChild(img); + } + img.src = `data:image/png;base64,${base64Image}`; +} + +function disableGenerateButton() { + const button = document.getElementById('send-btn'); + button.disabled = true; + button.textContent = 'Generating...'; +} + +function enableGenerateButton() { + const button = document.getElementById('send-btn'); + button.disabled = false; + button.textContent = 'Generate'; +} + +function resetForm() { + document.getElementById('num_inference_steps').value = '50'; + document.getElementById('guidance_scale').value = '7.5'; + document.getElementById('width').value = '512'; + document.getElementById('height').value = '512'; + document.getElementById('seed').value = '-1'; +} + +const menuToggle = document.getElementById('menu-toggle'); +const configContainer = document.querySelector('.config-container'); + +menuToggle.addEventListener('click', () => { + configContainer.classList.toggle('show'); +}); + +document.addEventListener('click', (event) => { + if (!configContainer.contains(event.target) && !menuToggle.contains(event.target)) { + configContainer.classList.remove('show'); + } +}); \ No newline at end of file diff --git a/src/AGISwarm/text2image_ms/app/styles.css b/src/AGISwarm/text2image_ms/app/styles.css new file mode 100644 index 0000000..b9dbce4 --- /dev/null +++ b/src/AGISwarm/text2image_ms/app/styles.css @@ -0,0 +1,127 @@ +body, html { + font-family: Arial, sans-serif; + margin: 0; + padding: 0; + background-color: #121212; + color: #ffffff; +} + +.container { + display: flex; + max-width: 100vw; + height: 100vh; + padding: 20px; +} + +.generation-container { + flex-grow: 1; + display: flex; + flex-direction: column; + margin-right: 20px; + max-width: 70%; +} + +.config-container { + background-color: #1e1e1e; + padding: 20px; + border-radius: 5px; + width: 30%; +} + +.form-group { + margin-bottom: 15px; +} + +.form-group label { + display: block; + margin-bottom: 5px; +} + +.form-control { + width: 100%; + padding: 8px; + background-color: #2d2d2d; + color: #ffffff; + border: none; + border-radius: 4px; +} + +.btn { + padding: 8px 15px; +} + +.image-output { + flex-grow: 1; + background-color: #1e1e1e; + border-radius: 5px; + margin-bottom: 15px; + display: flex; + justify-content: center; + align-items: center; + overflow: hidden; +} + +.image-output img { + max-width: 100%; + max-height: 100%; + object-fit: contain; +} + +.input-container { + display: flex; + margin-bottom: 15px; +} + +.input-container .form-control { + flex-grow: 1; + margin-right: 10px; +} + +#send-btn { + background-color: #363d46; + border-color: #363d46; +} + +.generation-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 15px; +} + +#title { + font-size: 24px; + margin: 0; +} + +.menu-toggle { + display: none; +} + +@media (max-width: 991px) { + .container { + flex-direction: column; + } + + .generation-container { + max-width: 100%; + margin-right: 0; + margin-bottom: 20px; + } + + .config-container { + width: 100%; + } + + .menu-toggle { + display: block; + } + + .config-container { + display: none; + } + + .config-container.show { + display: block; + } +} \ No newline at end of file diff --git a/src/AGISwarm/text2image_ms/diffusion_pipeline.py b/src/AGISwarm/text2image_ms/diffusion_pipeline.py new file mode 100644 index 0000000..1d2d8fb --- /dev/null +++ b/src/AGISwarm/text2image_ms/diffusion_pipeline.py @@ -0,0 +1,175 @@ +""" +This module contains the Stable Diffusion Pipeline class +that is used to generate images from text prompts using the Stable Diffusion model. +""" + +import asyncio +import threading +from functools import partial +from typing import Optional, Union + +import numpy as np +import torch +from diffusers import DiffusionPipeline, StableDiffusionPipeline +from diffusers.callbacks import PipelineCallback +from PIL import Image + +from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig +from .utils import generation_request_queued_func + + +class DiffusionIteratorStreamer: + + def __init__(self, timeout: Optional[Union[float, int]] = None): + self.latents_stack = [] + self.stop_signal: Optional[str] = None + self.timeout = timeout + self.current_step = 0 + self.total_steps = 0 + self.stop = False + + def put(self, latents: torch.Tensor): + """Метод для добавления латентов в очередь""" + self.latents_stack.append(latents.cpu().numpy()) + + def end(self): + """Метод для сигнализации окончания генерации""" + self.stop = True + + def __aiter__(self) -> "DiffusionIteratorStreamer": + return self + + async def __anext__(self) -> torch.Tensor | str: + while len(self.latents_stack) == 0: + await asyncio.sleep(0.1) + latents = self.latents_stack.pop() + if self.stop: + raise StopAsyncIteration() + return latents + + def callback( + self, + pipeline: DiffusionPipeline, + step: int, + timestep: int, + callback_kwargs: dict, + ): + """Callback для StableDiffusionPipeline""" + self.current_step = step + self.put(callback_kwargs["latents"]) + return {"latents": callback_kwargs["latents"]} + + def stream( + self, + pipe: StableDiffusionPipeline, + prompt: str, + negative_prompt: str, + num_inference_steps: int, + guidance_scale: float, + seed: int, + width: int, + height: int, + ): + """Method to stream the diffusion pipeline""" + self.total_steps = num_inference_steps + + def run_pipeline(): + pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + width=width, + height=height, + callback_on_step_end=self.callback, # type: ignore + ) + self.end() + + thread = threading.Thread(target=run_pipeline) + thread.start() + return self + + +class Text2ImagePipeline: + """ + A class to generate images from text prompts using the Stable Diffusion model. + + Args: + config (Text2ImagePipelineConfig): The configuration for the Diffusion Pipeline initialization. + - model (str): The model to use for generating the image. + - dtype (str): The data type to use for the model. + - device (str): The device to run the model on. + - safety_checker (str | None): The safety checker to use for the model. + - requires_safety_checker (bool): Whether the model requires a safety checker. + - low_cpu_mem_usage (bool): Whether to use low CPU memory usage. + """ + + def __init__(self, config: Text2ImagePipelineConfig): + self.config = config + self.pipeline = StableDiffusionPipeline.from_pretrained( + config.model, + torch_dtype=getattr(torch, config.dtype), + safety_checker=config.safety_checker, + requires_safety_checker=config.requires_safety_checker, + low_cpu_mem_usage=config.low_cpu_mem_usage, + ).to(config.device) + + self.pipeline.vae.enable_tiling() + self.pipeline.vae.enable_slicing() + # self.pipeline.enable_sequential_cpu_offload() + + @partial(generation_request_queued_func, wait_time=0.1) + async def generate( + self, + request_id: str, + gen_config: Text2ImageGenerationConfig + ): + """ + Generate an image from a text prompt using the Text2Image pipeline. + + Args: + gen_config (Text2ImageGenerationConfig): The configuration for the Text2Image Pipeline generation. + - prompt (str): The text prompt to generate the image from. + - negative_prompt (str): The negative text prompt to generate the image from. + - num_inference_steps (int): The number of inference steps to run. + - guidance_scale (float): The guidance scale to use for the model. + - seed (int): The seed to use for the model. + - width (int): The width of the image to generate. + - height (int): The height of the image to generate. + + Yields: + dict: A dictionary containing the step information for the generation. + - step (int): The current step of the generation. + - total_steps (int): The total number of steps for the generation. + - image (PIL.Image): The generated image + """ + streamer = DiffusionIteratorStreamer() + streamer.stream( + self.pipeline, + prompt=gen_config.prompt, + negative_prompt=gen_config.negative_prompt, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + seed=gen_config.seed, + width=gen_config.width, + height=gen_config.height, + ) + + async for latents in streamer: + if latents is None: + await asyncio.sleep(0.1) + continue + + latents = torch.tensor(latents, device=self.config.device) + with torch.no_grad(): + image = self.pipeline.decode_latents(latents)[0] + image = Image.fromarray((image * 255).astype(np.uint8)) + await asyncio.sleep(0.1) + yield { + "type": "generation_step", + "request_id": request_id, + "step": streamer.current_step, + "total_steps": streamer.total_steps, + "image": image, + } diff --git a/src/AGISwarm/text2image_ms/typing.py b/src/AGISwarm/text2image_ms/typing.py new file mode 100644 index 0000000..674c035 --- /dev/null +++ b/src/AGISwarm/text2image_ms/typing.py @@ -0,0 +1,36 @@ +""" +This module contains the typing classes for the Text2Image Pipeline. + +""" + +from dataclasses import dataclass + +from pydantic import BaseModel + + +@dataclass +class Text2ImagePipelineConfig(BaseModel): + """ + A class to hold the configuration for the Diffusion Pipeline initialization. + """ + + model: str + dtype: str + device: str + safety_checker: str | None + requires_safety_checker: bool + low_cpu_mem_usage: bool + + +class Text2ImageGenerationConfig(BaseModel): + """ + A class to hold the configuration for the Text2Image Pipeline generation. + """ + + prompt: str + negative_prompt: str + num_inference_steps: int + guidance_scale: float + seed: int + width: int + height: int \ No newline at end of file diff --git a/src/AGISwarm/text2image_ms/utils.py b/src/AGISwarm/text2image_ms/utils.py new file mode 100644 index 0000000..93874e6 --- /dev/null +++ b/src/AGISwarm/text2image_ms/utils.py @@ -0,0 +1,61 @@ +"""Utility functions for LLM engines""" + +import asyncio +import threading +from abc import abstractmethod +from typing import Dict, Generic, List, Protocol, TypeVar, cast, runtime_checkable + +from pydantic import BaseModel + +__ABORT_EVENTS = {} +__QUEUE = [] + + +def abort_generation_request(request_id: str): + """Abort generation request""" + if request_id in __ABORT_EVENTS: + __ABORT_EVENTS[request_id].set() + + +def generation_request_queued_func(func, wait_time=0.2): + """Decorator for generation requests""" + + def abort_response(request_id: str): + return { + "request_id": request_id, + "type": "abort", + "msg": "Generation aborted.", + } + + def waiting_response(request_id: str): + """Waiting response""" + return { + "request_id": request_id, + "type": "waiting", + "msg": f"Waiting for {__QUEUE.index(request_id)} requests to finish...\n", + } + + async def wrapper(*args, **kwargs): + request_id = args[1] + __ABORT_EVENTS[request_id] = threading.Event() + __QUEUE.append(request_id) + try: + while __QUEUE[0] != request_id: + await asyncio.sleep(wait_time) + if __ABORT_EVENTS[request_id].is_set(): + yield abort_response(request_id) + return + yield waiting_response(request_id) + async for response in func(*args, **kwargs): + if __ABORT_EVENTS[request_id].is_set(): + yield abort_response(request_id) + return + yield response + except asyncio.CancelledError as e: + print(e) + finally: + __QUEUE.remove(request_id) + __ABORT_EVENTS[request_id].clear() + __ABORT_EVENTS.pop(request_id) + + return wrapper