generated from AGISwarm/python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3e3ac09
commit 2c76312
Showing
12 changed files
with
724 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ __pycache__/ | |
.cache | ||
|
||
# due to using vscode | ||
.vscode/ | ||
.vscode/ | ||
outputs/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[settings] | ||
profile = black |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
model: "runwayml/stable-diffusion-v1-5" | ||
dtype: "float32" | ||
device: "cuda:0" | ||
safety_checker: | ||
requires_safety_checker: True | ||
low_cpu_mem_usage: False | ||
|
||
hydra: | ||
job: | ||
chdir: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"pythonVersion": "3.13", // Specify the Python version to use | ||
"pythonPlatform": "Linux", // Specify the platform (Windows, Linux, or macOS) | ||
"typeCheckingMode": "basic", // Set the type checking mode (basic or strict) | ||
"reportMissingTypeStubs": false, // Report missing type stubs for imported modules | ||
"reportPrivateImportUsage": false, // Report usage of private imports | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,116 @@ | ||
"""Main.py""" | ||
""" | ||
This module is the entry point for the text2image_ms service. | ||
from . import __version__ | ||
""" | ||
|
||
import asyncio | ||
import base64 | ||
import logging | ||
import traceback | ||
import uuid | ||
from io import BytesIO | ||
from pathlib import Path | ||
|
||
def main(): | ||
"""Main function""" | ||
print(__version__) | ||
import hydra | ||
import uvicorn | ||
from fastapi import Body, FastAPI, WebSocket | ||
from fastapi.responses import FileResponse, HTMLResponse | ||
from fastapi.staticfiles import StaticFiles | ||
from hydra.core.config_store import ConfigStore | ||
from jinja2 import Environment, FileSystemLoader | ||
from PIL import Image | ||
|
||
from .diffusion_pipeline import Text2ImagePipeline | ||
from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig | ||
|
||
|
||
class Text2ImageApp: | ||
""" | ||
A class to represent the Text2Image service. | ||
""" | ||
|
||
def __init__(self, config: Text2ImagePipelineConfig): | ||
self.app = FastAPI(debug=True) | ||
self.text2image_pipeline = Text2ImagePipeline(config) | ||
self.setup_routes() | ||
|
||
def setup_routes(self): | ||
""" | ||
Set up the routes for the Text2Image service. | ||
""" | ||
self.app.get("/", response_class=HTMLResponse)(self.gui) | ||
self.app.mount( | ||
"/static", | ||
StaticFiles(directory=Path(__file__).parent / "app", html=True), | ||
name="static", | ||
) | ||
self.app.add_websocket_route("/ws", self.generate) | ||
|
||
async def generate(self, websocket: WebSocket): | ||
""" | ||
Generate an image from a text prompt using the Text2Image pipeline. | ||
""" | ||
|
||
await websocket.accept() | ||
|
||
try: | ||
while True: | ||
await asyncio.sleep(0.1) | ||
data = await websocket.receive_text() | ||
print (data) | ||
gen_config = Text2ImageGenerationConfig.model_validate_json(data) | ||
request_id = str(uuid.uuid4()) | ||
async for step_info in self.text2image_pipeline.generate(request_id, gen_config): | ||
if step_info['type'] == 'waiting': | ||
await websocket.send_json(step_info) | ||
continue | ||
latents = step_info['image'] | ||
|
||
# Конвертируем латенты в base64 | ||
image_io = BytesIO() | ||
latents.save(image_io, 'PNG') | ||
dataurl = 'data:image/png;base64,' + base64.b64encode(image_io.getvalue()).decode('ascii') | ||
# Отправляем инфу о прогрессе и латенты | ||
await websocket.send_json({ | ||
"type": "generation_step", | ||
"step": step_info['step'], | ||
"total_steps": step_info['total_steps'], | ||
"latents": dataurl, | ||
"shape": latents.size | ||
}) | ||
|
||
await websocket.send_json({ | ||
"type": "generation_complete" | ||
}) | ||
except Exception as e: | ||
logging.error(e) | ||
traceback.print_exc() | ||
await websocket.send_json( | ||
{ | ||
"type": "error", | ||
"message": str(e), | ||
} | ||
) | ||
finally: | ||
await websocket.close() | ||
|
||
def gui(self): | ||
""" | ||
Get the GUI for the Text2Image service. | ||
""" | ||
print("GUI") | ||
path = Path(__file__).parent / "app" / "gui.html" | ||
return FileResponse(path) | ||
|
||
|
||
@hydra.main(config_name="config") | ||
def main(config: Text2ImagePipelineConfig): | ||
""" | ||
The main function for the Text2Image service. | ||
""" | ||
text2image_app = Text2ImageApp(config) | ||
uvicorn.run(text2image_app.app, host="localhost", port=8002, log_level="debug") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
main() # pylint: disable=no-value-for-parameter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
<!DOCTYPE html> | ||
<html lang="en"> | ||
<head> | ||
<meta charset="UTF-8"> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||
<title>Stable Diffusion Streaming Interface</title> | ||
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"> | ||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css"> | ||
<link rel="stylesheet" href="/static/styles.css"> | ||
</head> | ||
<body> | ||
<div class="container"> | ||
<div class="generation-container"> | ||
<div class="generation-header"> | ||
<h1 id="title">Stable Diffusion Streaming</h1> | ||
<button id="menu-toggle" class="btn btn-secondary menu-toggle"><i class="fas fa-bars"></i></button> | ||
</div> | ||
<div id="image-output" class="image-output"></div> | ||
<div class="input-container"> | ||
<textarea id="prompt" class="form-control" placeholder="Enter your prompt" rows="2"></textarea> | ||
<button onclick="sendMessage()" class="btn btn-primary" id="send-btn">Generate</button> | ||
</div> | ||
<textarea id="negative_prompt" class="form-control" placeholder="Enter your negative prompt" rows="2"></textarea> | ||
</div> | ||
|
||
<div class="config-container"> | ||
<div class="form-group"> | ||
<label for="num_inference_steps">Inference Steps:</label> | ||
<input type="number" id="num_inference_steps" class="form-control" value="50"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="guidance_scale">Guidance Scale:</label> | ||
<input type="number" id="guidance_scale" class="form-control" value="7.5" step="0.1"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="width">Width:</label> | ||
<input type="number" id="width" class="form-control" value="512" step="8"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="height">Height:</label> | ||
<input type="number" id="height" class="form-control" value="512" step="8"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="seed">Seed:</label> | ||
<input type="number" id="seed" class="form-control" value="-1"> | ||
</div> | ||
<button type="button" class="btn btn-secondary" onclick="resetForm()">Reset</button> | ||
</div> | ||
</div> | ||
<script> | ||
const WEBSOCKET_URL = "/ws"; | ||
const ABORT_URL = "/abort"; | ||
</script> | ||
<script src="/static/script.js"></script> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
let ws = new WebSocket(WEBSOCKET_URL); | ||
let currentRequestID = ''; | ||
|
||
ws.onopen = function () { | ||
console.log("WebSocket connection established"); | ||
}; | ||
|
||
ws.onmessage = function (event) { | ||
const data = JSON.parse(event.data); | ||
console.log('Message received:', data); | ||
if (data["type"] == "generation_complete") { | ||
enableGenerateButton(); | ||
return; | ||
}; | ||
if (data["type"] == "waiting") { | ||
msg = data["msg"]; | ||
console.log(msg); | ||
return; | ||
} | ||
|
||
imgElement = document.getElementById('image-output'); | ||
let img = imgElement.querySelector('img'); | ||
if (!img) { | ||
img = document.createElement('img'); | ||
imgElement.appendChild(img); | ||
} | ||
img.src = data['latents']; | ||
console.log('Image updated:', data); | ||
}; | ||
|
||
ws.onclose = function (event) { | ||
console.log("WebSocket connection closed"); | ||
}; | ||
|
||
function decodeBase64Image(base64String) { | ||
// Remove data URL prefix if present | ||
const base64Data = base64String.replace(/^data:image\/\w+;base64,/, ''); | ||
|
||
// Decode base64 | ||
const binaryString = atob(base64Data); | ||
|
||
// Create Uint8Array | ||
const uint8Array = new Uint8Array(binaryString.length); | ||
for (let i = 0; i < binaryString.length; i++) { | ||
uint8Array[i] = binaryString.charCodeAt(i); | ||
} | ||
|
||
// Create Blob | ||
return new Blob([uint8Array], { type: 'image/png' }); | ||
}; | ||
|
||
|
||
function sendMessage() { | ||
const prompt = document.getElementById('prompt').value; | ||
const negative_prompt = document.getElementById('negative_prompt').value; | ||
const num_inference_steps = document.getElementById('num_inference_steps').value; | ||
const guidance_scale = document.getElementById('guidance_scale').value; | ||
const width = document.getElementById('width').value; | ||
const height = document.getElementById('height').value; | ||
const seed = document.getElementById('seed').value; | ||
|
||
const message = { | ||
prompt: prompt, | ||
negative_prompt: negative_prompt, | ||
num_inference_steps: parseInt(num_inference_steps), | ||
guidance_scale: parseFloat(guidance_scale), | ||
width: parseInt(width), | ||
height: parseInt(height), | ||
seed: parseInt(seed) | ||
}; | ||
|
||
ws.send(JSON.stringify(message)); | ||
console.log('Message sent:', message); | ||
disableGenerateButton(); | ||
}; | ||
|
||
function updateStatus(step, total_steps) { | ||
const statusElement = document.getElementById('status'); | ||
if (!statusElement) { | ||
const statusElement = document.createElement('div'); | ||
statusElement.id = 'status'; | ||
document.getElementById('image-output').appendChild(statusElement); | ||
} | ||
statusElement.textContent = `Generating: ${step}/${total_steps}`; | ||
} | ||
|
||
|
||
function updateImage(base64Image) { | ||
const img = document.getElementById('generated-image'); | ||
if (!img) { | ||
const img = document.createElement('img'); | ||
img.id = 'generated-image'; | ||
document.body.appendChild(img); | ||
} | ||
img.src = `data:image/png;base64,${base64Image}`; | ||
} | ||
|
||
function disableGenerateButton() { | ||
const button = document.getElementById('send-btn'); | ||
button.disabled = true; | ||
button.textContent = 'Generating...'; | ||
} | ||
|
||
function enableGenerateButton() { | ||
const button = document.getElementById('send-btn'); | ||
button.disabled = false; | ||
button.textContent = 'Generate'; | ||
} | ||
|
||
function resetForm() { | ||
document.getElementById('num_inference_steps').value = '50'; | ||
document.getElementById('guidance_scale').value = '7.5'; | ||
document.getElementById('width').value = '512'; | ||
document.getElementById('height').value = '512'; | ||
document.getElementById('seed').value = '-1'; | ||
} | ||
|
||
const menuToggle = document.getElementById('menu-toggle'); | ||
const configContainer = document.querySelector('.config-container'); | ||
|
||
menuToggle.addEventListener('click', () => { | ||
configContainer.classList.toggle('show'); | ||
}); | ||
|
||
document.addEventListener('click', (event) => { | ||
if (!configContainer.contains(event.target) && !menuToggle.contains(event.target)) { | ||
configContainer.classList.remove('show'); | ||
} | ||
}); |
Oops, something went wrong.