diff --git a/.gitignore b/.gitignore
index 2d60ce0..5a172d0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,4 +13,5 @@ __pycache__/
.cache
# due to using vscode
-.vscode/
\ No newline at end of file
+.vscode/
+outputs/
\ No newline at end of file
diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000..6860bdb
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,2 @@
+[settings]
+profile = black
\ No newline at end of file
diff --git a/conf/config.yaml b/conf/config.yaml
new file mode 100644
index 0000000..3ec0df5
--- /dev/null
+++ b/conf/config.yaml
@@ -0,0 +1,10 @@
+model: "runwayml/stable-diffusion-v1-5"
+dtype: "float32"
+device: "cuda:0"
+safety_checker:
+requires_safety_checker: True
+low_cpu_mem_usage: False
+
+hydra:
+ job:
+ chdir: false
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 4f95d99..0175518 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,13 +8,20 @@ dynamic = ["version"]
description = "Python template project"
readme = "README.md"
-requires-python = ">=3.11"
+requires-python = ">=3.10"
license = { file = "LICENSE.txt" }
keywords = ["sample", "setuptools", "development"]
classifiers = [
"Programming Language :: Python :: 3",
]
-dependencies = ['numpy']
+dependencies = [
+    'numpy<2.0.0',
+    "flask",
+    "pillow",
+    "diffusers",
+    "torch",
+    "deepspeed",
+    "fastapi",
+    "uvicorn",
+    "hydra-core",
+    "pydantic",
+]
[project.optional-dependencies]
test = ['pytest']
analyze = ['pyright', 'pylint', 'bandit', 'black', 'isort']
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000..0a65641
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,7 @@
+{
+    "pythonVersion": "3.13", // Python version to assume when type checking
+    "pythonPlatform": "Linux", // Target platform (Windows, Linux, or Darwin)
+    "typeCheckingMode": "basic", // Type checking mode (off, basic, or strict)
+    "reportMissingTypeStubs": false, // Don't flag missing type stubs for imported modules
+    "reportPrivateImportUsage": false // Don't flag usage of private imports
+}
\ No newline at end of file
diff --git a/src/AGISwarm/text2image_ms/__main__.py b/src/AGISwarm/text2image_ms/__main__.py
index 826ad60..ea49858 100644
--- a/src/AGISwarm/text2image_ms/__main__.py
+++ b/src/AGISwarm/text2image_ms/__main__.py
@@ -1,12 +1,116 @@
-"""Main.py"""
+"""
+This module is the entry point for the text2image_ms service.
-from . import __version__
+"""
+import base64
+import logging
+import traceback
+import uuid
+from io import BytesIO
+from pathlib import Path
-def main():
- """Main function"""
- print(__version__)
+import hydra
+import uvicorn
+from fastapi import FastAPI, WebSocket
+from fastapi.responses import FileResponse, HTMLResponse
+from fastapi.staticfiles import StaticFiles
+from PIL import Image
+
+from .diffusion_pipeline import Text2ImagePipeline
+from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig
+
+
+class Text2ImageApp:
+ """
+ A class to represent the Text2Image service.
+ """
+
+ def __init__(self, config: Text2ImagePipelineConfig):
+ self.app = FastAPI(debug=True)
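+        # The diffusion pipeline is loaded once at startup and shared by all
+        # websocket connections served by this app.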
+ self.text2image_pipeline = Text2ImagePipeline(config)
+ self.setup_routes()
+
+ def setup_routes(self):
+ """
+ Set up the routes for the Text2Image service.
+ """
+ self.app.get("/", response_class=HTMLResponse)(self.gui)
+ self.app.mount(
+ "/static",
+ StaticFiles(directory=Path(__file__).parent / "app", html=True),
+ name="static",
+ )
+ self.app.add_websocket_route("/ws", self.generate)
+
+ async def generate(self, websocket: WebSocket):
+ """
+ Generate an image from a text prompt using the Text2Image pipeline.
+ """
+
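+        # Protocol: each client message is a JSON-encoded
+        # Text2ImageGenerationConfig; the server streams back "waiting",
+        # "generation_step" and "generation_complete" events.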
+ await websocket.accept()
+
+ try:
+            while True:
+                data = await websocket.receive_text()
+                logging.debug("Received generation request: %s", data)
+                gen_config = Text2ImageGenerationConfig.model_validate_json(data)
+                request_id = str(uuid.uuid4())
+                async for step_info in self.text2image_pipeline.generate(
+                    request_id, gen_config
+                ):
+                    if step_info["type"] != "generation_step":
+                        # Forward "waiting" (and abort) notifications as-is.
+                        await websocket.send_json(step_info)
+                        continue
+                    image: Image.Image = step_info["image"]
+
+                    # Encode the preview image as a base64 PNG data URL
+                    image_io = BytesIO()
+                    image.save(image_io, "PNG")
+                    dataurl = "data:image/png;base64," + base64.b64encode(
+                        image_io.getvalue()
+                    ).decode("ascii")
+                    # Send the progress info together with the preview image
+                    await websocket.send_json(
+                        {
+                            "type": "generation_step",
+                            "step": step_info["step"],
+                            "total_steps": step_info["total_steps"],
+                            "latents": dataurl,
+                            "shape": image.size,
+                        }
+                    )
+
+ await websocket.send_json({
+ "type": "generation_complete"
+ })
+        except Exception as e:  # pylint: disable=broad-except
+ logging.error(e)
+ traceback.print_exc()
+ await websocket.send_json(
+ {
+ "type": "error",
+ "message": str(e),
+ }
+ )
+ finally:
+ await websocket.close()
+
+ def gui(self):
+ """
+ Get the GUI for the Text2Image service.
+ """
+ print("GUI")
+ path = Path(__file__).parent / "app" / "gui.html"
+ return FileResponse(path)
+
+
+@hydra.main(config_name="config")
+def main(config: Text2ImagePipelineConfig):
+ """
+ The main function for the Text2Image service.
+ """
+ text2image_app = Text2ImageApp(config)
+ uvicorn.run(text2image_app.app, host="localhost", port=8002, log_level="debug")
if __name__ == "__main__":
- main()
+ main() # pylint: disable=no-value-for-parameter
diff --git a/src/AGISwarm/text2image_ms/app/gui.html b/src/AGISwarm/text2image_ms/app/gui.html
new file mode 100644
index 0000000..4d543ea
--- /dev/null
+++ b/src/AGISwarm/text2image_ms/app/gui.html
@@ -0,0 +1,56 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Stable Diffusion Streaming Interface</title>
+    <link rel="stylesheet" href="/static/styles.css">
+</head>
+
+<body>
+    <div class="container">
+        <div class="generation-container">
+            <div class="generation-header">
+                <h1 id="title">Stable Diffusion Streaming Interface</h1>
+                <button id="menu-toggle" class="menu-toggle btn">Settings</button>
+            </div>
+            <div id="image-output" class="image-output"></div>
+            <div class="input-container">
+                <input type="text" id="prompt" class="form-control" placeholder="Enter a prompt...">
+                <button id="send-btn" class="btn" onclick="sendMessage()">Generate</button>
+            </div>
+        </div>
+        <div class="config-container">
+            <div class="form-group">
+                <label for="negative_prompt">Negative prompt</label>
+                <input type="text" id="negative_prompt" class="form-control" value="">
+            </div>
+            <div class="form-group">
+                <label for="num_inference_steps">Inference steps</label>
+                <input type="number" id="num_inference_steps" class="form-control" value="50">
+            </div>
+            <div class="form-group">
+                <label for="guidance_scale">Guidance scale</label>
+                <input type="number" id="guidance_scale" class="form-control" value="7.5" step="0.1">
+            </div>
+            <div class="form-group">
+                <label for="width">Width</label>
+                <input type="number" id="width" class="form-control" value="512">
+            </div>
+            <div class="form-group">
+                <label for="height">Height</label>
+                <input type="number" id="height" class="form-control" value="512">
+            </div>
+            <div class="form-group">
+                <label for="seed">Seed</label>
+                <input type="number" id="seed" class="form-control" value="-1">
+            </div>
+            <button class="btn" onclick="resetForm()">Reset</button>
+        </div>
+    </div>
+    <script>const WEBSOCKET_URL = `ws://${window.location.host}/ws`;</script>
+    <script src="/static/script.js"></script>
+</body>
+
+</html>
diff --git a/src/AGISwarm/text2image_ms/app/script.js b/src/AGISwarm/text2image_ms/app/script.js
new file mode 100644
index 0000000..85e5c37
--- /dev/null
+++ b/src/AGISwarm/text2image_ms/app/script.js
@@ -0,0 +1,129 @@
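+// WEBSOCKET_URL is expected to be defined globally by gui.html before this script loads.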
+let ws = new WebSocket(WEBSOCKET_URL);
+let currentRequestID = '';
+
+ws.onopen = function () {
+ console.log("WebSocket connection established");
+};
+
+ws.onmessage = function (event) {
+ const data = JSON.parse(event.data);
+ console.log('Message received:', data);
+ if (data["type"] == "generation_complete") {
+ enableGenerateButton();
+ return;
+ };
+ if (data["type"] == "waiting") {
+ msg = data["msg"];
+ console.log(msg);
+ return;
+ }
+
+    const imgElement = document.getElementById('image-output');
+ let img = imgElement.querySelector('img');
+ if (!img) {
+ img = document.createElement('img');
+ imgElement.appendChild(img);
+ }
+ img.src = data['latents'];
+ console.log('Image updated:', data);
+};
+
+ws.onclose = function (event) {
+ console.log("WebSocket connection closed");
+};
+
+function decodeBase64Image(base64String) {
+ // Remove data URL prefix if present
+ const base64Data = base64String.replace(/^data:image\/\w+;base64,/, '');
+
+ // Decode base64
+ const binaryString = atob(base64Data);
+
+ // Create Uint8Array
+ const uint8Array = new Uint8Array(binaryString.length);
+ for (let i = 0; i < binaryString.length; i++) {
+ uint8Array[i] = binaryString.charCodeAt(i);
+ }
+
+ // Create Blob
+ return new Blob([uint8Array], { type: 'image/png' });
+}
+
+
+function sendMessage() {
+ const prompt = document.getElementById('prompt').value;
+ const negative_prompt = document.getElementById('negative_prompt').value;
+ const num_inference_steps = document.getElementById('num_inference_steps').value;
+ const guidance_scale = document.getElementById('guidance_scale').value;
+ const width = document.getElementById('width').value;
+ const height = document.getElementById('height').value;
+ const seed = document.getElementById('seed').value;
+
+ const message = {
+ prompt: prompt,
+ negative_prompt: negative_prompt,
+ num_inference_steps: parseInt(num_inference_steps),
+ guidance_scale: parseFloat(guidance_scale),
+ width: parseInt(width),
+ height: parseInt(height),
+ seed: parseInt(seed)
+ };
+
+ ws.send(JSON.stringify(message));
+ console.log('Message sent:', message);
+ disableGenerateButton();
+}
+
+function updateStatus(step, total_steps) {
+    let statusElement = document.getElementById('status');
+    if (!statusElement) {
+        statusElement = document.createElement('div');
+        statusElement.id = 'status';
+        document.getElementById('image-output').appendChild(statusElement);
+    }
+    statusElement.textContent = `Generating: ${step}/${total_steps}`;
+}
+
+
+function updateImage(base64Image) {
+    let img = document.getElementById('generated-image');
+    if (!img) {
+        img = document.createElement('img');
+        img.id = 'generated-image';
+        document.body.appendChild(img);
+    }
+    img.src = `data:image/png;base64,${base64Image}`;
+}
+
+function disableGenerateButton() {
+ const button = document.getElementById('send-btn');
+ button.disabled = true;
+ button.textContent = 'Generating...';
+}
+
+function enableGenerateButton() {
+ const button = document.getElementById('send-btn');
+ button.disabled = false;
+ button.textContent = 'Generate';
+}
+
+function resetForm() {
+ document.getElementById('num_inference_steps').value = '50';
+ document.getElementById('guidance_scale').value = '7.5';
+ document.getElementById('width').value = '512';
+ document.getElementById('height').value = '512';
+ document.getElementById('seed').value = '-1';
+}
+
+const menuToggle = document.getElementById('menu-toggle');
+const configContainer = document.querySelector('.config-container');
+
+menuToggle.addEventListener('click', () => {
+ configContainer.classList.toggle('show');
+});
+
+document.addEventListener('click', (event) => {
+ if (!configContainer.contains(event.target) && !menuToggle.contains(event.target)) {
+ configContainer.classList.remove('show');
+ }
+});
\ No newline at end of file
diff --git a/src/AGISwarm/text2image_ms/app/styles.css b/src/AGISwarm/text2image_ms/app/styles.css
new file mode 100644
index 0000000..b9dbce4
--- /dev/null
+++ b/src/AGISwarm/text2image_ms/app/styles.css
@@ -0,0 +1,127 @@
+body, html {
+ font-family: Arial, sans-serif;
+ margin: 0;
+ padding: 0;
+ background-color: #121212;
+ color: #ffffff;
+}
+
+.container {
+ display: flex;
+ max-width: 100vw;
+ height: 100vh;
+ padding: 20px;
+}
+
+.generation-container {
+ flex-grow: 1;
+ display: flex;
+ flex-direction: column;
+ margin-right: 20px;
+ max-width: 70%;
+}
+
+.config-container {
+ background-color: #1e1e1e;
+ padding: 20px;
+ border-radius: 5px;
+ width: 30%;
+}
+
+.form-group {
+ margin-bottom: 15px;
+}
+
+.form-group label {
+ display: block;
+ margin-bottom: 5px;
+}
+
+.form-control {
+ width: 100%;
+ padding: 8px;
+ background-color: #2d2d2d;
+ color: #ffffff;
+ border: none;
+ border-radius: 4px;
+}
+
+.btn {
+ padding: 8px 15px;
+}
+
+.image-output {
+ flex-grow: 1;
+ background-color: #1e1e1e;
+ border-radius: 5px;
+ margin-bottom: 15px;
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ overflow: hidden;
+}
+
+.image-output img {
+ max-width: 100%;
+ max-height: 100%;
+ object-fit: contain;
+}
+
+.input-container {
+ display: flex;
+ margin-bottom: 15px;
+}
+
+.input-container .form-control {
+ flex-grow: 1;
+ margin-right: 10px;
+}
+
+#send-btn {
+ background-color: #363d46;
+ border-color: #363d46;
+}
+
+.generation-header {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ margin-bottom: 15px;
+}
+
+#title {
+ font-size: 24px;
+ margin: 0;
+}
+
+.menu-toggle {
+ display: none;
+}
+
+@media (max-width: 991px) {
+ .container {
+ flex-direction: column;
+ }
+
+ .generation-container {
+ max-width: 100%;
+ margin-right: 0;
+ margin-bottom: 20px;
+ }
+
+ .config-container {
+ width: 100%;
+ }
+
+ .menu-toggle {
+ display: block;
+ }
+
+ .config-container {
+ display: none;
+ }
+
+ .config-container.show {
+ display: block;
+ }
+}
\ No newline at end of file
diff --git a/src/AGISwarm/text2image_ms/diffusion_pipeline.py b/src/AGISwarm/text2image_ms/diffusion_pipeline.py
new file mode 100644
index 0000000..1d2d8fb
--- /dev/null
+++ b/src/AGISwarm/text2image_ms/diffusion_pipeline.py
@@ -0,0 +1,175 @@
+"""
+This module contains the Stable Diffusion Pipeline class
+that is used to generate images from text prompts using the Stable Diffusion model.
+"""
+
+import asyncio
+import threading
+from functools import partial
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from diffusers import DiffusionPipeline, StableDiffusionPipeline
+from PIL import Image
+
+from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig
+from .utils import generation_request_queued_func
+
+
+class DiffusionIteratorStreamer:
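+    """
+    Bridge between the synchronous diffusers callback API and async iteration:
+    the pipeline thread pushes intermediate latents via `put`/`callback`,
+    while the event loop consumes them with `async for`.
+    """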
+
+ def __init__(self, timeout: Optional[Union[float, int]] = None):
+ self.latents_stack = []
+ self.stop_signal: Optional[str] = None
+ self.timeout = timeout
+ self.current_step = 0
+ self.total_steps = 0
+ self.stop = False
+
+ def put(self, latents: torch.Tensor):
+ """Метод для добавления латентов в очередь"""
+ self.latents_stack.append(latents.cpu().numpy())
+
+ def end(self):
+ """Метод для сигнализации окончания генерации"""
+ self.stop = True
+
+ def __aiter__(self) -> "DiffusionIteratorStreamer":
+ return self
+
+    async def __anext__(self) -> np.ndarray:
+        while len(self.latents_stack) == 0:
+            if self.stop:
+                raise StopAsyncIteration()
+            await asyncio.sleep(0.1)
+        return self.latents_stack.pop(0)  # FIFO: yield steps in generation order
+
+ def callback(
+ self,
+ pipeline: DiffusionPipeline,
+ step: int,
+ timestep: int,
+ callback_kwargs: dict,
+ ):
+ """Callback для StableDiffusionPipeline"""
+ self.current_step = step
+ self.put(callback_kwargs["latents"])
+ return {"latents": callback_kwargs["latents"]}
+
+ def stream(
+ self,
+ pipe: StableDiffusionPipeline,
+ prompt: str,
+ negative_prompt: str,
+ num_inference_steps: int,
+ guidance_scale: float,
+ seed: int,
+ width: int,
+ height: int,
+ ):
+ """Method to stream the diffusion pipeline"""
+ self.total_steps = num_inference_steps
+
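+        # pipe() blocks until generation finishes, so it runs in a worker
+        # thread while the caller iterates over this streamer on the event loop.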
+        def run_pipeline():
+            # Assumed convention: a negative seed (the GUI default is -1)
+            # leaves the generator state random.
+            generator = (
+                torch.Generator(device=pipe.device).manual_seed(seed)
+                if seed >= 0
+                else None
+            )
+            pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                generator=generator,
+                width=width,
+                height=height,
+                callback_on_step_end=self.callback,  # type: ignore
+            )
+            self.end()
+
+ thread = threading.Thread(target=run_pipeline)
+ thread.start()
+ return self
+
+
+class Text2ImagePipeline:
+ """
+ A class to generate images from text prompts using the Stable Diffusion model.
+
+ Args:
+ config (Text2ImagePipelineConfig): The configuration for the Diffusion Pipeline initialization.
+ - model (str): The model to use for generating the image.
+ - dtype (str): The data type to use for the model.
+ - device (str): The device to run the model on.
+ - safety_checker (str | None): The safety checker to use for the model.
+ - requires_safety_checker (bool): Whether the model requires a safety checker.
+ - low_cpu_mem_usage (bool): Whether to use low CPU memory usage.
+ """
+
+ def __init__(self, config: Text2ImagePipelineConfig):
+ self.config = config
+ self.pipeline = StableDiffusionPipeline.from_pretrained(
+ config.model,
+ torch_dtype=getattr(torch, config.dtype),
+ safety_checker=config.safety_checker,
+ requires_safety_checker=config.requires_safety_checker,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
+ ).to(config.device)
+
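+        # Tiled and sliced VAE decoding reduce peak VRAM usage, which helps
+        # when decoding a preview image at every denoising step.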
+ self.pipeline.vae.enable_tiling()
+ self.pipeline.vae.enable_slicing()
+ # self.pipeline.enable_sequential_cpu_offload()
+
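+    # generation_request_queued_func serialises concurrent requests through a
+    # FIFO queue (see utils.py) so only one generation runs at a time.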
+ @partial(generation_request_queued_func, wait_time=0.1)
+ async def generate(
+ self,
+ request_id: str,
+ gen_config: Text2ImageGenerationConfig
+ ):
+ """
+ Generate an image from a text prompt using the Text2Image pipeline.
+
+ Args:
+ gen_config (Text2ImageGenerationConfig): The configuration for the Text2Image Pipeline generation.
+ - prompt (str): The text prompt to generate the image from.
+ - negative_prompt (str): The negative text prompt to generate the image from.
+ - num_inference_steps (int): The number of inference steps to run.
+ - guidance_scale (float): The guidance scale to use for the model.
+ - seed (int): The seed to use for the model.
+ - width (int): The width of the image to generate.
+ - height (int): The height of the image to generate.
+
+ Yields:
+ dict: A dictionary containing the step information for the generation.
+ - step (int): The current step of the generation.
+ - total_steps (int): The total number of steps for the generation.
+ - image (PIL.Image): The generated image
+ """
+ streamer = DiffusionIteratorStreamer()
+ streamer.stream(
+ self.pipeline,
+ prompt=gen_config.prompt,
+ negative_prompt=gen_config.negative_prompt,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ seed=gen_config.seed,
+ width=gen_config.width,
+ height=gen_config.height,
+ )
+
+        async for latents in streamer:
+
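+            # Decode the intermediate latents into a PIL preview image.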
+ latents = torch.tensor(latents, device=self.config.device)
+ with torch.no_grad():
+ image = self.pipeline.decode_latents(latents)[0]
+ image = Image.fromarray((image * 255).astype(np.uint8))
+ await asyncio.sleep(0.1)
+ yield {
+ "type": "generation_step",
+ "request_id": request_id,
+ "step": streamer.current_step,
+ "total_steps": streamer.total_steps,
+ "image": image,
+ }
diff --git a/src/AGISwarm/text2image_ms/typing.py b/src/AGISwarm/text2image_ms/typing.py
new file mode 100644
index 0000000..674c035
--- /dev/null
+++ b/src/AGISwarm/text2image_ms/typing.py
@@ -0,0 +1,36 @@
+"""
+This module contains the typing classes for the Text2Image Pipeline.
+"""
+
+from pydantic import BaseModel
+
+
+class Text2ImagePipelineConfig(BaseModel):
+ """
+ A class to hold the configuration for the Diffusion Pipeline initialization.
+ """
+
+ model: str
+ dtype: str
+ device: str
+ safety_checker: str | None
+ requires_safety_checker: bool
+ low_cpu_mem_usage: bool
+
+
+class Text2ImageGenerationConfig(BaseModel):
+ """
+ A class to hold the configuration for the Text2Image Pipeline generation.
+ """
+
+ prompt: str
+ negative_prompt: str
+ num_inference_steps: int
+ guidance_scale: float
+ seed: int
+ width: int
+ height: int
\ No newline at end of file
diff --git a/src/AGISwarm/text2image_ms/utils.py b/src/AGISwarm/text2image_ms/utils.py
new file mode 100644
index 0000000..93874e6
--- /dev/null
+++ b/src/AGISwarm/text2image_ms/utils.py
@@ -0,0 +1,61 @@
+"""Utility functions for LLM engines"""
+
+import asyncio
+import threading
+from typing import Dict, List
+
+# Per-request abort events and a FIFO queue of pending request ids.
+__ABORT_EVENTS: Dict[str, threading.Event] = {}
+__QUEUE: List[str] = []
+
+
+def abort_generation_request(request_id: str):
+ """Abort generation request"""
+ if request_id in __ABORT_EVENTS:
+ __ABORT_EVENTS[request_id].set()
+
+
+def generation_request_queued_func(func, wait_time=0.2):
+ """Decorator for generation requests"""
+
+ def abort_response(request_id: str):
+ return {
+ "request_id": request_id,
+ "type": "abort",
+ "msg": "Generation aborted.",
+ }
+
+ def waiting_response(request_id: str):
+ """Waiting response"""
+ return {
+ "request_id": request_id,
+ "type": "waiting",
+ "msg": f"Waiting for {__QUEUE.index(request_id)} requests to finish...\n",
+ }
+
+ async def wrapper(*args, **kwargs):
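+        # The wrapped function is an unbound generator method, so
+        # args = (self, request_id, gen_config).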
+ request_id = args[1]
+ __ABORT_EVENTS[request_id] = threading.Event()
+ __QUEUE.append(request_id)
+ try:
+ while __QUEUE[0] != request_id:
+ await asyncio.sleep(wait_time)
+ if __ABORT_EVENTS[request_id].is_set():
+ yield abort_response(request_id)
+ return
+ yield waiting_response(request_id)
+ async for response in func(*args, **kwargs):
+ if __ABORT_EVENTS[request_id].is_set():
+ yield abort_response(request_id)
+ return
+ yield response
+ except asyncio.CancelledError as e:
+ print(e)
+ finally:
+ __QUEUE.remove(request_id)
+ __ABORT_EVENTS[request_id].clear()
+ __ABORT_EVENTS.pop(request_id)
+
+ return wrapper