Skip to content

Commit

Permalink
all works
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisDiachkov committed Sep 3, 2024
1 parent 3e3ac09 commit 2c76312
Show file tree
Hide file tree
Showing 12 changed files with 724 additions and 9 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ __pycache__/
.cache

# due to using vscode
.vscode/
.vscode/
outputs/
2 changes: 2 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[settings]
profile = black
10 changes: 10 additions & 0 deletions conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
model: "runwayml/stable-diffusion-v1-5"
dtype: "float32"
device: "cuda:0"
safety_checker:
requires_safety_checker: True
low_cpu_mem_usage: False

hydra:
job:
chdir: false
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@ dynamic = ["version"]
description = "Python template project"
readme = "README.md"

requires-python = ">=3.11"
requires-python = ">=3.10"
license = { file = "LICENSE.txt" }
keywords = ["sample", "setuptools", "development"]
classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = ['numpy']
dependencies = [
'numpy<2.0.0',
"flask",
"pillow",
"diffusers",
"torch",
"deepspeed"
]
[project.optional-dependencies]
test = ['pytest']
analyze = ['pyright', 'pylint', 'bandit', 'black', 'isort']
Expand Down
7 changes: 7 additions & 0 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"pythonVersion": "3.13", // Specify the Python version to use
"pythonPlatform": "Linux", // Specify the platform (Windows, Linux, or macOS)
"typeCheckingMode": "basic", // Set the type checking mode (basic or strict)
"reportMissingTypeStubs": false, // Report missing type stubs for imported modules
"reportPrivateImportUsage": false, // Report usage of private imports
}
116 changes: 110 additions & 6 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,116 @@
"""Main.py"""
"""
This module is the entry point for the text2image_ms service.
from . import __version__
"""

import asyncio
import base64
import logging
import traceback
import uuid
from io import BytesIO
from pathlib import Path

def main():
"""Main function"""
print(__version__)
import hydra
import uvicorn
from fastapi import Body, FastAPI, WebSocket
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from hydra.core.config_store import ConfigStore
from jinja2 import Environment, FileSystemLoader
from PIL import Image

from .diffusion_pipeline import Text2ImagePipeline
from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig


class Text2ImageApp:
"""
A class to represent the Text2Image service.
"""

def __init__(self, config: Text2ImagePipelineConfig):
self.app = FastAPI(debug=True)
self.text2image_pipeline = Text2ImagePipeline(config)
self.setup_routes()

def setup_routes(self):
"""
Set up the routes for the Text2Image service.
"""
self.app.get("/", response_class=HTMLResponse)(self.gui)
self.app.mount(
"/static",
StaticFiles(directory=Path(__file__).parent / "app", html=True),
name="static",
)
self.app.add_websocket_route("/ws", self.generate)

async def generate(self, websocket: WebSocket):
"""
Generate an image from a text prompt using the Text2Image pipeline.
"""

await websocket.accept()

try:
while True:
await asyncio.sleep(0.1)
data = await websocket.receive_text()
print (data)
gen_config = Text2ImageGenerationConfig.model_validate_json(data)
request_id = str(uuid.uuid4())
async for step_info in self.text2image_pipeline.generate(request_id, gen_config):
if step_info['type'] == 'waiting':
await websocket.send_json(step_info)
continue
latents = step_info['image']

# Конвертируем латенты в base64
image_io = BytesIO()
latents.save(image_io, 'PNG')
dataurl = 'data:image/png;base64,' + base64.b64encode(image_io.getvalue()).decode('ascii')
# Отправляем инфу о прогрессе и латенты
await websocket.send_json({
"type": "generation_step",
"step": step_info['step'],
"total_steps": step_info['total_steps'],
"latents": dataurl,
"shape": latents.size
})

await websocket.send_json({
"type": "generation_complete"
})
except Exception as e:
logging.error(e)
traceback.print_exc()
await websocket.send_json(
{
"type": "error",
"message": str(e),
}
)
finally:
await websocket.close()

def gui(self):
"""
Get the GUI for the Text2Image service.
"""
print("GUI")
path = Path(__file__).parent / "app" / "gui.html"
return FileResponse(path)


@hydra.main(config_name="config")
def main(config: Text2ImagePipelineConfig):
"""
The main function for the Text2Image service.
"""
text2image_app = Text2ImageApp(config)
uvicorn.run(text2image_app.app, host="localhost", port=8002, log_level="debug")


if __name__ == "__main__":
main()
main() # pylint: disable=no-value-for-parameter
56 changes: 56 additions & 0 deletions src/AGISwarm/text2image_ms/app/gui.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Stable Diffusion Streaming Interface</title>
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
<link rel="stylesheet" href="/static/styles.css">
</head>
<body>
<div class="container">
<div class="generation-container">
<div class="generation-header">
<h1 id="title">Stable Diffusion Streaming</h1>
<button id="menu-toggle" class="btn btn-secondary menu-toggle"><i class="fas fa-bars"></i></button>
</div>
<div id="image-output" class="image-output"></div>
<div class="input-container">
<textarea id="prompt" class="form-control" placeholder="Enter your prompt" rows="2"></textarea>
<button onclick="sendMessage()" class="btn btn-primary" id="send-btn">Generate</button>
</div>
<textarea id="negative_prompt" class="form-control" placeholder="Enter your negative prompt" rows="2"></textarea>
</div>

<div class="config-container">
<div class="form-group">
<label for="num_inference_steps">Inference Steps:</label>
<input type="number" id="num_inference_steps" class="form-control" value="50">
</div>
<div class="form-group">
<label for="guidance_scale">Guidance Scale:</label>
<input type="number" id="guidance_scale" class="form-control" value="7.5" step="0.1">
</div>
<div class="form-group">
<label for="width">Width:</label>
<input type="number" id="width" class="form-control" value="512" step="8">
</div>
<div class="form-group">
<label for="height">Height:</label>
<input type="number" id="height" class="form-control" value="512" step="8">
</div>
<div class="form-group">
<label for="seed">Seed:</label>
<input type="number" id="seed" class="form-control" value="-1">
</div>
<button type="button" class="btn btn-secondary" onclick="resetForm()">Reset</button>
</div>
</div>
<script>
const WEBSOCKET_URL = "/ws";
const ABORT_URL = "/abort";
</script>
<script src="/static/script.js"></script>
</body>
</html>
129 changes: 129 additions & 0 deletions src/AGISwarm/text2image_ms/app/script.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
let ws = new WebSocket(WEBSOCKET_URL);
let currentRequestID = '';

ws.onopen = function () {
console.log("WebSocket connection established");
};

ws.onmessage = function (event) {
const data = JSON.parse(event.data);
console.log('Message received:', data);
if (data["type"] == "generation_complete") {
enableGenerateButton();
return;
};
if (data["type"] == "waiting") {
msg = data["msg"];
console.log(msg);
return;
}

imgElement = document.getElementById('image-output');
let img = imgElement.querySelector('img');
if (!img) {
img = document.createElement('img');
imgElement.appendChild(img);
}
img.src = data['latents'];
console.log('Image updated:', data);
};

ws.onclose = function (event) {
console.log("WebSocket connection closed");
};

function decodeBase64Image(base64String) {
// Remove data URL prefix if present
const base64Data = base64String.replace(/^data:image\/\w+;base64,/, '');

// Decode base64
const binaryString = atob(base64Data);

// Create Uint8Array
const uint8Array = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
uint8Array[i] = binaryString.charCodeAt(i);
}

// Create Blob
return new Blob([uint8Array], { type: 'image/png' });
};


function sendMessage() {
const prompt = document.getElementById('prompt').value;
const negative_prompt = document.getElementById('negative_prompt').value;
const num_inference_steps = document.getElementById('num_inference_steps').value;
const guidance_scale = document.getElementById('guidance_scale').value;
const width = document.getElementById('width').value;
const height = document.getElementById('height').value;
const seed = document.getElementById('seed').value;

const message = {
prompt: prompt,
negative_prompt: negative_prompt,
num_inference_steps: parseInt(num_inference_steps),
guidance_scale: parseFloat(guidance_scale),
width: parseInt(width),
height: parseInt(height),
seed: parseInt(seed)
};

ws.send(JSON.stringify(message));
console.log('Message sent:', message);
disableGenerateButton();
};

function updateStatus(step, total_steps) {
const statusElement = document.getElementById('status');
if (!statusElement) {
const statusElement = document.createElement('div');
statusElement.id = 'status';
document.getElementById('image-output').appendChild(statusElement);
}
statusElement.textContent = `Generating: ${step}/${total_steps}`;
}


function updateImage(base64Image) {
const img = document.getElementById('generated-image');
if (!img) {
const img = document.createElement('img');
img.id = 'generated-image';
document.body.appendChild(img);
}
img.src = `data:image/png;base64,${base64Image}`;
}

function disableGenerateButton() {
const button = document.getElementById('send-btn');
button.disabled = true;
button.textContent = 'Generating...';
}

function enableGenerateButton() {
const button = document.getElementById('send-btn');
button.disabled = false;
button.textContent = 'Generate';
}

function resetForm() {
document.getElementById('num_inference_steps').value = '50';
document.getElementById('guidance_scale').value = '7.5';
document.getElementById('width').value = '512';
document.getElementById('height').value = '512';
document.getElementById('seed').value = '-1';
}

const menuToggle = document.getElementById('menu-toggle');
const configContainer = document.querySelector('.config-container');

menuToggle.addEventListener('click', () => {
configContainer.classList.toggle('show');
});

document.addEventListener('click', (event) => {
if (!configContainer.contains(event.target) && !menuToggle.contains(event.target)) {
configContainer.classList.remove('show');
}
});
Loading

0 comments on commit 2c76312

Please sign in to comment.