Skip to content

Commit

Permalink
Merge pull request #2 from AGISwarm/basic_functions
Browse files Browse the repository at this point in the history
all works
  • Loading branch information
DenisDiachkov authored Sep 4, 2024
2 parents f43e42b + a03339f commit f29c36d
Show file tree
Hide file tree
Showing 12 changed files with 764 additions and 9 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build-analysis-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ jobs:
steps:
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: "3.11"
- name: Lint (Pylint)
run: source .venv/bin/activate && pylint src
- name: Format check (Black)
Expand All @@ -53,5 +55,7 @@ jobs:
steps:
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: "3.11"
- name: Unittests (Pytest)
run: source .venv/bin/activate && pytest tests/unittests
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ __pycache__/
.cache

# due to using vscode
.vscode/
.vscode/
outputs/
4 changes: 4 additions & 0 deletions CICD/analyze.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
python -m pylint src
python -m pyright src
python -m black src --check
python -m isort src --check-only
10 changes: 10 additions & 0 deletions conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
model: "runwayml/stable-diffusion-v1-5"
dtype: "float32"
device: "cuda:0"
safety_checker:
requires_safety_checker: True
low_cpu_mem_usage: False

hydra:
job:
chdir: false
22 changes: 20 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,25 @@ dynamic = ["version"]
description = "Python template project"
readme = "README.md"

requires-python = ">=3.11"
requires-python = ">=3.10"
license = { file = "LICENSE.txt" }
keywords = ["sample", "setuptools", "development"]
classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = ['numpy']
dependencies = [
'numpy<2.0.0',
"flask",
"pillow",
"diffusers",
"torch",
"deepspeed",
"hydra-core",
"fastapi",
"pillow",
"uvicorn",
"AGISwarm.asyncio_queue_manager"
]
[project.optional-dependencies]
test = ['pytest']
analyze = ['pyright', 'pylint', 'bandit', 'black', 'isort']
Expand All @@ -31,3 +43,9 @@ where = ["src"]

[tool.setuptools.package-data]
text2image_ms = ["py.typed"]

[tool.isort]
profile = "black"

[tool.pylint.'MESSAGES CONTROL']
disable = "wrong-import-order"
7 changes: 7 additions & 0 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"pythonVersion": "3.13", // Specify the Python version to use
"pythonPlatform": "Linux", // Specify the platform (Windows, Linux, or macOS)
"typeCheckingMode": "basic", // Set the type checking mode (basic or strict)
"reportMissingTypeStubs": false, // Report missing type stubs for imported modules
"reportPrivateImportUsage": false, // Report usage of private imports
}
139 changes: 133 additions & 6 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,139 @@
"""Main.py"""
"""
This module is the entry point for the text2image_ms service.
from . import __version__
"""

import asyncio
import base64
import logging
import traceback
from io import BytesIO
from pathlib import Path

def main():
"""Main function"""
print(__version__)
import hydra
import uvicorn
from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, RequestStatus
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic.main import BaseModel

from .diffusion_pipeline import Text2ImagePipeline
from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig


class Text2ImageApp:
"""
A class to represent the Text2Image service.
"""

def __init__(self, config: Text2ImagePipelineConfig):
self.app = FastAPI()
self.setup_routes()
self.queue_manager = AsyncIOQueueManager()
self.text2image_pipeline = Text2ImagePipeline(config)

def setup_routes(self):
"""
Set up the routes for the Text2Image service.
"""
self.app.get("/", response_class=HTMLResponse)(self.gui)
self.app.mount(
"/static",
StaticFiles(directory=Path(__file__).parent / "app", html=True),
name="static",
)

self.ws_router = APIRouter()
self.ws_router.add_websocket_route("/ws", self.generate)
self.ws_router.post("/abort")(self.abort)
self.app.include_router(self.ws_router)

async def generate(self, websocket: WebSocket):
"""
Generate an image from a text prompt using the Text2Image pipeline.
"""

await websocket.accept()

try:
while True:
await asyncio.sleep(0.01)
data = await websocket.receive_text()
print(data)
gen_config = Text2ImageGenerationConfig.model_validate_json(data)
generator = self.queue_manager.queued_generator(
self.text2image_pipeline.generate
)
request_id = generator.request_id
interrupt_event = self.queue_manager.abort_map[request_id]

async for step_info in generator(
gen_config, interrupt_event=interrupt_event
):
await asyncio.sleep(0.01)
print(step_info)
if step_info["status"] == RequestStatus.WAITING:
await websocket.send_json(step_info)
continue
if step_info["status"] != RequestStatus.RUNNING:
await websocket.send_json(step_info)
break
latents = step_info["image"]
image_io = BytesIO()
latents.save(image_io, "PNG")
dataurl = "data:image/png;base64," + base64.b64encode(
image_io.getvalue()
).decode("ascii")
await websocket.send_json(
{
"request_id": request_id,
"status": RequestStatus.RUNNING,
"step": step_info["step"],
"total_steps": step_info["total_steps"],
"latents": dataurl,
"shape": latents.size,
}
)
except Exception as e: # pylint: disable=broad-except
logging.error(e)
traceback.print_exc()
await websocket.send_json(
{
"status": RequestStatus.ERROR,
"message": str(e), ### loggging
}
)
finally:
await websocket.close()

class AbortRequest(BaseModel):
"""Abort request"""

request_id: str

async def abort(self, request: AbortRequest):
"""Abort generation"""
print(f"Aborting request {request.request_id}")
await self.queue_manager.abort_task(request.request_id)

def gui(self):
"""
Get the GUI for the Text2Image service.
"""
print("GUI")
path = Path(__file__).parent / "app" / "gui.html"
return FileResponse(path)


@hydra.main(config_name="config")
def main(config: Text2ImagePipelineConfig):
"""
The main function for the Text2Image service.
"""
text2image_app = Text2ImageApp(config)
uvicorn.run(text2image_app.app, host="127.0.0.1", port=8002, log_level="debug")


if __name__ == "__main__":
main()
main() # pylint: disable=no-value-for-parameter
56 changes: 56 additions & 0 deletions src/AGISwarm/text2image_ms/app/gui.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Stable Diffusion Streaming Interface</title>
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
<link rel="stylesheet" href="/static/styles.css">
</head>
<body>
<div class="container">
<div class="generation-container">
<div class="generation-header">
<h1 id="title">Stable Diffusion Streaming</h1>
<button id="menu-toggle" class="btn btn-secondary menu-toggle"><i class="fas fa-bars"></i></button>
</div>
<div id="image-output" class="image-output"></div>
<div class="input-container">
<textarea id="prompt" class="form-control" placeholder="Enter your prompt" rows="2"></textarea>
<button onclick="sendButtonClick()" class="btn btn-primary" id="send-btn">Send</button>
</div>
<textarea id="negative_prompt" class="form-control" placeholder="Enter your negative prompt" rows="2"></textarea>
</div>

<div class="config-container">
<div class="form-group">
<label for="num_inference_steps">Inference Steps:</label>
<input type="number" id="num_inference_steps" class="form-control" value="50">
</div>
<div class="form-group">
<label for="guidance_scale">Guidance Scale:</label>
<input type="number" id="guidance_scale" class="form-control" value="7.5" step="0.1">
</div>
<div class="form-group">
<label for="width">Width:</label>
<input type="number" id="width" class="form-control" value="512" step="8">
</div>
<div class="form-group">
<label for="height">Height:</label>
<input type="number" id="height" class="form-control" value="512" step="8">
</div>
<div class="form-group">
<label for="seed">Seed:</label>
<input type="number" id="seed" class="form-control" value="-1">
</div>
<button type="button" class="btn btn-secondary" onclick="resetForm()">Reset</button>
</div>
</div>
<script>
const WEBSOCKET_URL = "/ws";
const ABORT_URL = "/abort";
</script>
<script src="/static/script.js"></script>
</body>
</html>
Loading

0 comments on commit f29c36d

Please sign in to comment.