Skip to content

Commit

Permalink
fix diffusers 0.4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
LowinLi committed Oct 7, 2022
1 parent c9c9351 commit 945107e
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 57 deletions.
3 changes: 2 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ src/stable-diffusion-streamlit/pages/model/result
src/stable-diffusion-streamlit/pages/model/onnx
__pycache__
docker/.env
tag.sh
tag.sh
docker/volume
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ result
onnx
__pycache__
docker/.env
tag.sh
tag.sh
docker/volume
15 changes: 15 additions & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
version: "2.3"
services:
stable-diffusion-streamlit-onnxquantized:
container_name: stable-diffusion-streamlit-onnxquantized
image: lowinli98/stable-diffusion-streamlit-onnxquantized:v0.1
expose:
- 8501
ports:
- "8501:8501"
environment:
- APP_TITLE=Stable Diffusion Streamlit
restart: always
volumes:
- /etc/localtime:/etc/localtime
- ./volume:/app/pages/model/result
4 changes: 1 addition & 3 deletions docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ ftfy==6.1.1
onnx==1.12.0
onnxruntime==1.12.1
streamlit==1.13.0
streamlit-image-comparison==0.0.2
transformers==4.22.2
diffusers==0.4.0
torch==1.10.0+cpu
pandas==1.4.1
torch==1.10.0+cpu
10 changes: 8 additions & 2 deletions src/stable-diffusion-streamlit/pages/model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
import json

root = os.getcwd()

if root[-3:] == "src":
last_dir = os.path.split(root)[-1]
if last_dir == "stable-diffusion-streamlit":
model_dir = os.path.join(root, "pages/model/onnx")
result_dir = os.path.join(root, "pages/model/result")
elif last_dir == "pages":
model_dir = os.path.join(root, "model/onnx")
result_dir = os.path.join(root, "model/result")
elif last_dir == "app":
model_dir = os.path.join(root, "pages/model/onnx")
result_dir = os.path.join(root, "pages/model/result")
else:
Expand Down
35 changes: 18 additions & 17 deletions src/stable-diffusion-streamlit/pages/model/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@

def quant():
for root, dirs, filenames in os.walk("./onnx"):
for filename in filenames:
if "model.onnx" in filenames:
if "weights.pb" in filenames:
external_data = True
else:
external_data = False
quantize_dynamic(
model_input=os.path.join(root, "model.onnx"),
model_output=os.path.join(root, "model.onnx"), # 量化后直接覆盖原onnx文件
per_channel=True,
reduce_range=True,
weight_type=QuantType.QUInt8,
optimize_model=True,
use_external_data_format=external_data,
)
print("Quantized model saved at: ", os.path.join(root, "model.onnx"))

if "model.onnx" in filenames:
if "weights.pb" in filenames:
external_data = True
else:
external_data = False
quantize_dynamic(
model_input=os.path.join(root, "model.onnx"),
model_output=os.path.join(root, "model.onnx"), # 量化后直接覆盖原onnx文件
per_channel=True,
reduce_range=True,
weight_type=QuantType.QUInt8,
optimize_model=True,
use_external_data_format=external_data,
)
print("Quantized model saved at: ", os.path.join(root, "model.onnx"))
if "weights.pb" in filenames:
os.remove(os.path.join(root, "weights.pb"))
print("Removed weights.pb")

if __name__ == "__main__":
quant()
33 changes: 20 additions & 13 deletions src/stable-diffusion-streamlit/pages/文字转图片.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import streamlit as st
import pandas as pd
import numpy as np
import time
import threading
Expand All @@ -25,15 +24,6 @@
ce, c1, ce, c2, c3 = st.columns([0.07, 1, 0.07, 5, 0.07])
with c1:
st.subheader("参数配置", anchor=None)
guidance_scale = st.slider(
"指导参数(guidance_scale)",
min_value=0.0,
max_value=30.0,
value=7.0,
step=0.1,
help="值越大,约接近输入文字 \n Defined in https://arxiv.org/abs/2207.12598 \n Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.",
label_visibility="visible",
)
num_inference_steps = st.slider(
"生成轮数(num_inference_steps)",
min_value=5,
Expand All @@ -43,6 +33,15 @@
label_visibility="visible",
help="约大生成图片质量越高,但是速度越慢 \n The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.",
)
guidance_scale = st.slider(
"指导参数(guidance_scale)",
min_value=0.0,
max_value=30.0,
value=7.0,
step=0.1,
help="值越大,约接近输入文字 \n Defined in https://arxiv.org/abs/2207.12598 \n Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.",
label_visibility="visible",
)
height = st.slider(
"高度像素(height)",
min_value=64,
Expand Down Expand Up @@ -86,6 +85,14 @@
disabled=False,
label_visibility="visible",
)
negative_prompt = st.text_area(
"输入不要生成的文字描述,不填为不使用",
value="",
help="The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).",
disabled=False,
label_visibility="visible",
)
negative_prompt = None
submit_button = st.form_submit_button("开始生成", help=None, args=None, kwargs=None)
my_bar = st.progress(0)
if not submit_button:
Expand All @@ -98,8 +105,8 @@
width,
num_inference_steps,
guidance_scale,
eta,
None,
negative_prompt,
eta
)
t = PipelineThread(func=quant_pipe, args=args)
t.start()
Expand All @@ -110,7 +117,7 @@
else:
counter = 0
progress = min(
t.func.scheduler.counter / (t.func.scheduler.num_inference_steps + 1),
t.func.scheduler.counter / (num_inference_steps + 1),
1.0,
)
my_bar.progress(progress)
Expand Down
42 changes: 22 additions & 20 deletions src/stable-diffusion-streamlit/pages/画廊.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,28 @@
import os
import json

tab1, tab2 = st.tabs(["文字转图片"])
# tab1 = st.tabs(["文字转图片"])

with tab1:
list_uid = os.listdir("pages/model/result/text2image")
list_uid = sorted(list_uid, reverse=True)
# with tab1:
result_dir = "pages/model/result/text2image"
os.makedirs(result_dir, exist_ok=True)
list_uid = os.listdir(result_dir)
list_uid = sorted(list_uid, reverse=True)

for uid in list_uid:
try:
with open(
os.path.join("pages/model/result/text2image", uid, "config.json"), "r"
) as f:
config = json.load(f)
for uid in list_uid:
try:
with open(
os.path.join(result_dir, uid, "config.json"), "r"
) as f:
config = json.load(f)

with st.container():
st.caption(uid)
st.image(
f"pages/model/result/text2image/{uid}/image.png",
caption=str(config["text_prompt"]),
)
st.json(config, expanded=False)
st.markdown("---")
except:
pass
with st.container():
st.caption(uid)
st.image(
os.path.join(result_dir, uid, "image.png"),
caption=str(config["text_prompt"]),
)
st.json(config, expanded=False)
st.markdown("---")
except:
pass

0 comments on commit 945107e

Please sign in to comment.