Skip to content

Commit

Permalink
small bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisDiachkov committed Sep 5, 2024
1 parent 41bc9fe commit 1df2fed
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 44 deletions.
58 changes: 33 additions & 25 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,26 @@ def setup_routes(self):
self.ws_router.post("/abort")(self.abort)
self.app.include_router(self.ws_router)

@staticmethod
def send_image(websocket: WebSocket, image: Image.Image, **kwargs):
"""
Send an image to the client.
"""
image_io = BytesIO()
image.save(image_io, "PNG")
dataurl = "data:image/png;base64," + base64.b64encode(
image_io.getvalue()
).decode("ascii")
asyncio_run(
websocket.send_json(
{
"image": dataurl,
"shape": image.size,
}
| kwargs
)
)

@staticmethod
def diffusion_pipeline_step_callback(
websocket: WebSocket,
Expand All @@ -95,31 +115,18 @@ def diffusion_pipeline_step_callback(
"""Callback для StableDiffusionPipeline"""
if abort_event.is_set():
raise asyncio.CancelledError("Diffusion pipeline aborted")

if step == 0 or step != total_steps and step % latent_update_frequency != 0:
return {"latents": callback_kwargs["latents"]}

with torch.no_grad():
image = pipeline.decode_latents(callback_kwargs["latents"].clone())[0]
image = Image.fromarray((image * 255).astype(np.uint8))

image_io = BytesIO()
image.save(image_io, "PNG")
dataurl = "data:image/png;base64," + base64.b64encode(
image_io.getvalue()
).decode("ascii")

asyncio_run(
websocket.send_json(
{
"request_id": request_id,
"status": RequestStatus.RUNNING,
"step": step,
"total_steps": total_steps,
"latents": dataurl,
"shape": image.size,
}
)
Text2ImageApp.send_image(
websocket,
image,
request_id=request_id,
status=RequestStatus.RUNNING,
step=step,
total_steps=total_steps,
)
return {"latents": callback_kwargs["latents"]}

Expand Down Expand Up @@ -152,7 +159,7 @@ async def generate(self, websocket: WebSocket):
request_id,
abort_event,
gen_config.num_inference_steps,
self.latent_update_frequency
self.latent_update_frequency,
)

# Start the generation task
Expand All @@ -161,10 +168,11 @@ async def generate(self, websocket: WebSocket):
gen_config, callback_on_step_end
):
if "status" not in step_info: # Task's return value.
await websocket.send_json(
{
"status": RequestStatus.FINISHED,
}
Text2ImageApp.send_image(
websocket,
step_info["image"],
request_id=request_id,
status=RequestStatus.FINISHED,
)
break
if (
Expand Down
26 changes: 17 additions & 9 deletions src/AGISwarm/text2image_ms/app/script.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ws.onmessage = function (event) {
const data = JSON.parse(event.data);
currentRequestID = data["request_id"];
console.log('Message received:', data);
if (data["status"] == "finished" || data["status"] == "aborted") {
if(data["status"] == "aborted") {
enableGenerateButton();
return;
};
Expand All @@ -26,13 +26,17 @@ ws.onmessage = function (event) {
enableGenerateButton();
return;
}
if (data["status"] == "finished"){
updateStatus("Finished", "Finished");
enableGenerateButton();
}
imgElement = document.getElementById('image-output');
let img = imgElement.querySelector('img');
if (!img) {
img = document.createElement('img');
imgElement.appendChild(img);
}
img.src = data['latents'];
img.src = data['image'];
console.log('Image updated:', data);
};

Expand Down Expand Up @@ -120,7 +124,7 @@ function enterSend(event) {
// Allow new line with Ctrl+Enter
this.value += '\n';
}
}
};

document.getElementById('prompt').addEventListener('keydown', enterSend);
document.getElementById('negative_prompt').addEventListener('keydown', enterSend);
Expand All @@ -132,8 +136,12 @@ function updateStatus(step, total_steps) {
statusElement.id = 'status';
document.getElementById('image-output').appendChild(statusElement);
}
statusElement.textContent = `Generating: ${step}/${total_steps}`;
}
if (step === "Finished") {
statusElement.setAttribute(`textContent`, `Finished`);
return;
}
statusElement.setAttribute(`textContent`,`Generating: ${step}/${total_steps}`);
};


function updateImage(base64Image) {
Expand All @@ -144,27 +152,27 @@ function updateImage(base64Image) {
document.body.appendChild(img);
}
img.src = `data:image/png;base64,${base64Image}`;
}
};

function disableGenerateButton() {
document.getElementById('send-btn').style.backgroundColor = "#808080";
document.getElementById('send-btn').textContent = "Abort";
document.getElementById('send-btn').disabled = false;
}
};

function enableGenerateButton() {
document.getElementById('send-btn').style.backgroundColor = "#363d46";
document.getElementById('send-btn').textContent = "Send";
document.getElementById('send-btn').disabled = false;
}
};

function resetForm() {
document.getElementById('num_inference_steps').value = '50';
document.getElementById('guidance_scale').value = '7.5';
document.getElementById('width').value = '512';
document.getElementById('height').value = '512';
document.getElementById('seed').value = '-1';
}
};

const menuToggle = document.getElementById('menu-toggle');
const configContainer = document.querySelector('.config-container');
Expand Down
23 changes: 13 additions & 10 deletions src/AGISwarm/text2image_ms/diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
that is used to generate images from text prompts using the Stable Diffusion model.
"""


from typing import Callable, Optional

import numpy as np
Expand Down Expand Up @@ -80,13 +81,15 @@ def generate(
callback_on_step_end (Optional[Callable[[dict], None]):
The callback function to call on each step end.
"""
return self.pipeline(
prompt=gen_config.prompt,
negative_prompt=gen_config.negative_prompt,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
seed=gen_config.seed,
width=gen_config.width,
height=gen_config.height,
callback_on_step_end=callback_on_step_end,
)
return {
"image": self.pipeline(
prompt=gen_config.prompt,
negative_prompt=gen_config.negative_prompt,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
seed=gen_config.seed,
width=gen_config.width,
height=gen_config.height,
callback_on_step_end=callback_on_step_end,
)["images"][0]
}

0 comments on commit 1df2fed

Please sign in to comment.