generated from AGISwarm/python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from AGISwarm/basic_functions
all works
- Loading branch information
Showing
12 changed files
with
764 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ __pycache__/ | |
.cache | ||
|
||
# due to using vscode | ||
.vscode/ | ||
.vscode/ | ||
outputs/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
python -m pylint src | ||
python -m pyright src | ||
python -m black src --check | ||
python -m isort src --check-only |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
model: "runwayml/stable-diffusion-v1-5" | ||
dtype: "float32" | ||
device: "cuda:0" | ||
safety_checker: | ||
requires_safety_checker: True | ||
low_cpu_mem_usage: False | ||
|
||
hydra: | ||
job: | ||
chdir: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"pythonVersion": "3.13", // Specify the Python version to use | ||
"pythonPlatform": "Linux", // Specify the platform (Windows, Linux, or macOS) | ||
"typeCheckingMode": "basic", // Set the type checking mode (basic or strict) | ||
"reportMissingTypeStubs": false, // Report missing type stubs for imported modules | ||
"reportPrivateImportUsage": false, // Report usage of private imports | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,139 @@ | ||
"""Main.py""" | ||
""" | ||
This module is the entry point for the text2image_ms service. | ||
from . import __version__ | ||
""" | ||
|
||
import asyncio | ||
import base64 | ||
import logging | ||
import traceback | ||
from io import BytesIO | ||
from pathlib import Path | ||
|
||
def main(): | ||
"""Main function""" | ||
print(__version__) | ||
import hydra | ||
import uvicorn | ||
from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, RequestStatus | ||
from fastapi import APIRouter, FastAPI, WebSocket | ||
from fastapi.responses import FileResponse, HTMLResponse | ||
from fastapi.staticfiles import StaticFiles | ||
from pydantic.main import BaseModel | ||
|
||
from .diffusion_pipeline import Text2ImagePipeline | ||
from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig | ||
|
||
|
||
class Text2ImageApp: | ||
""" | ||
A class to represent the Text2Image service. | ||
""" | ||
|
||
def __init__(self, config: Text2ImagePipelineConfig): | ||
self.app = FastAPI() | ||
self.setup_routes() | ||
self.queue_manager = AsyncIOQueueManager() | ||
self.text2image_pipeline = Text2ImagePipeline(config) | ||
|
||
def setup_routes(self): | ||
""" | ||
Set up the routes for the Text2Image service. | ||
""" | ||
self.app.get("/", response_class=HTMLResponse)(self.gui) | ||
self.app.mount( | ||
"/static", | ||
StaticFiles(directory=Path(__file__).parent / "app", html=True), | ||
name="static", | ||
) | ||
|
||
self.ws_router = APIRouter() | ||
self.ws_router.add_websocket_route("/ws", self.generate) | ||
self.ws_router.post("/abort")(self.abort) | ||
self.app.include_router(self.ws_router) | ||
|
||
async def generate(self, websocket: WebSocket): | ||
""" | ||
Generate an image from a text prompt using the Text2Image pipeline. | ||
""" | ||
|
||
await websocket.accept() | ||
|
||
try: | ||
while True: | ||
await asyncio.sleep(0.01) | ||
data = await websocket.receive_text() | ||
print(data) | ||
gen_config = Text2ImageGenerationConfig.model_validate_json(data) | ||
generator = self.queue_manager.queued_generator( | ||
self.text2image_pipeline.generate | ||
) | ||
request_id = generator.request_id | ||
interrupt_event = self.queue_manager.abort_map[request_id] | ||
|
||
async for step_info in generator( | ||
gen_config, interrupt_event=interrupt_event | ||
): | ||
await asyncio.sleep(0.01) | ||
print(step_info) | ||
if step_info["status"] == RequestStatus.WAITING: | ||
await websocket.send_json(step_info) | ||
continue | ||
if step_info["status"] != RequestStatus.RUNNING: | ||
await websocket.send_json(step_info) | ||
break | ||
latents = step_info["image"] | ||
image_io = BytesIO() | ||
latents.save(image_io, "PNG") | ||
dataurl = "data:image/png;base64," + base64.b64encode( | ||
image_io.getvalue() | ||
).decode("ascii") | ||
await websocket.send_json( | ||
{ | ||
"request_id": request_id, | ||
"status": RequestStatus.RUNNING, | ||
"step": step_info["step"], | ||
"total_steps": step_info["total_steps"], | ||
"latents": dataurl, | ||
"shape": latents.size, | ||
} | ||
) | ||
except Exception as e: # pylint: disable=broad-except | ||
logging.error(e) | ||
traceback.print_exc() | ||
await websocket.send_json( | ||
{ | ||
"status": RequestStatus.ERROR, | ||
"message": str(e), ### loggging | ||
} | ||
) | ||
finally: | ||
await websocket.close() | ||
|
||
class AbortRequest(BaseModel): | ||
"""Abort request""" | ||
|
||
request_id: str | ||
|
||
async def abort(self, request: AbortRequest): | ||
"""Abort generation""" | ||
print(f"Aborting request {request.request_id}") | ||
await self.queue_manager.abort_task(request.request_id) | ||
|
||
def gui(self): | ||
""" | ||
Get the GUI for the Text2Image service. | ||
""" | ||
print("GUI") | ||
path = Path(__file__).parent / "app" / "gui.html" | ||
return FileResponse(path) | ||
|
||
|
||
@hydra.main(config_name="config") | ||
def main(config: Text2ImagePipelineConfig): | ||
""" | ||
The main function for the Text2Image service. | ||
""" | ||
text2image_app = Text2ImageApp(config) | ||
uvicorn.run(text2image_app.app, host="127.0.0.1", port=8002, log_level="debug") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
main() # pylint: disable=no-value-for-parameter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
<!DOCTYPE html> | ||
<html lang="en"> | ||
<head> | ||
<meta charset="UTF-8"> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||
<title>Stable Diffusion Streaming Interface</title> | ||
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"> | ||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css"> | ||
<link rel="stylesheet" href="/static/styles.css"> | ||
</head> | ||
<body> | ||
<div class="container"> | ||
<div class="generation-container"> | ||
<div class="generation-header"> | ||
<h1 id="title">Stable Diffusion Streaming</h1> | ||
<button id="menu-toggle" class="btn btn-secondary menu-toggle"><i class="fas fa-bars"></i></button> | ||
</div> | ||
<div id="image-output" class="image-output"></div> | ||
<div class="input-container"> | ||
<textarea id="prompt" class="form-control" placeholder="Enter your prompt" rows="2"></textarea> | ||
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">Send</button> | ||
</div> | ||
<textarea id="negative_prompt" class="form-control" placeholder="Enter your negative prompt" rows="2"></textarea> | ||
</div> | ||
|
||
<div class="config-container"> | ||
<div class="form-group"> | ||
<label for="num_inference_steps">Inference Steps:</label> | ||
<input type="number" id="num_inference_steps" class="form-control" value="50"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="guidance_scale">Guidance Scale:</label> | ||
<input type="number" id="guidance_scale" class="form-control" value="7.5" step="0.1"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="width">Width:</label> | ||
<input type="number" id="width" class="form-control" value="512" step="8"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="height">Height:</label> | ||
<input type="number" id="height" class="form-control" value="512" step="8"> | ||
</div> | ||
<div class="form-group"> | ||
<label for="seed">Seed:</label> | ||
<input type="number" id="seed" class="form-control" value="-1"> | ||
</div> | ||
<button type="button" class="btn btn-secondary" onclick="resetForm()">Reset</button> | ||
</div> | ||
</div> | ||
<script> | ||
const WEBSOCKET_URL = "/ws"; | ||
const ABORT_URL = "/abort"; | ||
</script> | ||
<script src="/static/script.js"></script> | ||
</body> | ||
</html> |
Oops, something went wrong.