-
Notifications
You must be signed in to change notification settings - Fork 881
/
sdxl_minimal_inference.py
345 lines (283 loc) · 13.5 KB
/
sdxl_minimal_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
# 手元で推論を行うための最低限のコード。HuggingFace/DiffusersのCLIP、schedulerとVAEを使う
# Minimal code for performing inference at local. Use HuggingFace/Diffusers CLIP, scheduler and VAE
import argparse
import datetime
import math
import os
import random
from einops import repeat
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
from PIL import Image
# import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: these settings are apparently the same as for SD1/2
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDULER_SCHEDULE = "scaled_linear"
# Backward-compatible alias for the original misspelled name; existing code may still reference it.
SCHEDLER_SCHEDULE = SCHEDULER_SCHEDULE
# Time Embedding is copied from Diffusers
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of (possibly fractional) timesteps.

    :param timesteps: 1-D tensor holding N values, one per batch element.
    :param dim: width of each output embedding.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and simply tile each value `dim` times.
    :return: an [N x dim] tensor. In the sinusoidal case the first `dim // 2` columns are cosines,
        the next `dim // 2` are sines, and a trailing zero column is appended when `dim` is odd.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    half_dim = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim
    frequencies = torch.exp(exponents).to(device=timesteps.device)
    phases = timesteps[:, None].float() * frequencies[None]
    result = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2 != 0:
        # odd widths get one zero column so the output is exactly `dim` wide
        result = torch.cat([result, torch.zeros_like(result[:, :1])], dim=-1)
    return result
def get_timestep_embedding(x, outdim):
    """
    Sinusoidally embed every scalar of a 2-D tensor and concatenate the embeddings per row.

    Each row of `x` is a small vector of conditioning values (this script passes pairs such as
    height/width). Every value is embedded with `timestep_embedding(..., outdim)`, and the
    per-value embeddings of a row are laid out one after another.

    :param x: tensor of shape [B, D].
    :param outdim: embedding width for each scalar.
    :return: tensor of shape [B, D * outdim].
    """
    assert x.dim() == 2
    batch_size, num_values = x.shape
    per_value = timestep_embedding(x.reshape(-1), outdim)
    return per_value.reshape(batch_size, num_values * outdim)
if __name__ == "__main__":
    # Change the values below to change the image generation conditions.
    # They are passed to SDXL's additional vector embedding (original size, crop offset, target size).
    target_height = 1024
    target_width = 1024
    original_height = target_height
    original_width = target_width
    crop_top = 0
    crop_left = 0

    steps = 50
    guidance_scale = 7
    seed = None  # default seed (e.g. 1); overridden by --seed when given

    DEVICE = get_preferred_device()
    DTYPE = torch.float16  # bfloat16 may work

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--prompt2", type=str, default=None)
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="random seed for non-interactive generation (random if omitted)",
    )
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()
    if args.prompt2 is None:
        args.prompt2 = args.prompt
    if args.seed is not None:
        seed = args.seed

    # Create the output directory up front so saving does not fail after a long generation.
    os.makedirs(args.output_dir, exist_ok=True)

    # HuggingFace model ids
    text_encoder_1_name = "openai/clip-vit-large-patch14"
    text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    # Load the checkpoint. For model conversion, see load_models_from_sdxl_checkpoint.
    # If main RAM is small, it may be better to load it on the GPU.
    text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
        sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu"
    )

    # In SDXL itself, Text Encoder 1 is also the HuggingFace implementation.
    # SDXL itself uses open_clip for Text Encoder 2. That would work too, but to match the Diffusers
    # version of SD2 the HuggingFace implementation is used; the weight conversion is almost the same as SD2.

    # The VAE has the same structure as in SD1/2, but the weights seem to differ, and above all the
    # scale value is different. NaN seems more likely in fp16, so the VAE runs in float32 in that case.
    unet.to(DEVICE, dtype=DTYPE)
    unet.eval()

    vae_dtype = DTYPE
    if DTYPE == torch.float16:
        logger.info("use float32 for vae")
        vae_dtype = torch.float32
    vae.to(DEVICE, dtype=vae_dtype)
    vae.eval()

    text_model1.to(DEVICE, dtype=DTYPE)
    text_model1.eval()
    text_model2.to(DEVICE, dtype=DTYPE)
    text_model2.eval()

    unet.set_use_memory_efficient_attention(True, False)
    if torch.__version__ >= "2.0.0":  # usable with xformers builds that support PyTorch 2.0.0 or later
        vae.set_use_memory_efficient_attention_xformers(True)

    # Tokenizers
    tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
    # tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
    tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
    # Pad Text Encoder 2 input with id 0 to mirror the open_clip tokenizer used by SDXL itself.
    # NOTE(review): this is a no-op if the hub config already uses id 0 ("!") as pad token.
    tokenizer2.pad_token_id = 0

    # LoRA
    for weights_arg in args.lora_weights:
        if ";" in weights_arg:
            # split on the last ';' so only the trailing multiplier is separated from the path
            weights_file, multiplier_str = weights_arg.rsplit(";", 1)
            multiplier = float(multiplier_str)
        else:
            weights_file, multiplier = weights_arg, 1.0
        lora_model, weights_sd = lora.create_network_from_weights(
            multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
        )
        lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)

    # scheduler
    scheduler = EulerDiscreteScheduler(
        num_train_timesteps=SCHEDULER_TIMESTEPS,
        beta_start=SCHEDULER_LINEAR_START,
        beta_end=SCHEDULER_LINEAR_END,
        beta_schedule=SCHEDLER_SCHEDULE,
    )

    def _tokenize(tokenizer, prompt):
        """Tokenize `prompt` to fixed-length (max_length-padded, truncated) input ids on DEVICE."""
        return tokenizer(
            prompt,
            truncation=True,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )["input_ids"].to(DEVICE)

    def call_text_encoder(text, text2):
        """Run both text encoders: `text` through Text Encoder 1 and `text2` through Text Encoder 2.

        Returns (context, pooled): context concatenates, along dim 2, Text Encoder 1's hidden_states[11]
        and Text Encoder 2's penultimate hidden state; pooled is Text Encoder 2's `text_embeds`.
        """
        with torch.no_grad():
            # text encoder 1 (the final layer norm is apparently not applied)
            enc_out = text_model1(_tokenize(tokenizer1, text), output_hidden_states=True, return_dict=True)
            text_embedding1 = enc_out["hidden_states"][11]

            # text encoder 2: must use tokenizer2 on text2 (previously tokenizer1's ids for `text` were reused)
            enc_out = text_model2(_tokenize(tokenizer2, text2), output_hidden_states=True, return_dict=True)
            text_embedding2_penu = enc_out["hidden_states"][-2]
            text_embedding2_pool = enc_out["text_embeds"]  # do not support Textual Inversion

        # concat and finish
        text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
        return text_embedding, text_embedding2_pool

    def generate_image(prompt, prompt2, negative_prompt, seed=None):
        """Generate one image with classifier-free guidance and save it as a PNG in args.output_dir.

        `prompt` goes to Text Encoder 1 and `prompt2` to Text Encoder 2; `negative_prompt` feeds both
        encoders for the unconditional branch. When `seed` is given, the Python, NumPy and torch RNGs
        are seeded before the initial noise is drawn.
        """
        # TODO: make the size conditioning configurable per call
        with torch.no_grad():
            # vector conditioning: each scalar pair -> 2 * 256 sinusoidal features
            emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
            emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
            emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
            c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
            # The original author was not sure reusing the same size conditioning for uncond is right.
            uc_vector = c_vector.clone()

        # cond
        c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
        c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)

        # uncond
        uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
        uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)

        # batch order is [uncond, cond]; the chunk() in the denoising loop relies on it
        text_embeddings = torch.cat([uc_ctx, c_ctx])
        vector_embeddings = torch.cat([uc_vector, c_vector])

        # To reduce memory usage, the Text Encoders could be deleted or moved to CPU here.

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        # SDXL creates the initial latents on CPU (Diffusers creates them on the target device); follow SDXL.
        # Without an explicit generator, torch.randn draws from the global CPU RNG seeded above.
        latents_shape = (1, 4, target_height // 8, target_width // 8)
        latents = torch.randn(latents_shape, device="cpu", dtype=torch.float32).to(DEVICE, dtype=DTYPE)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler.init_noise_sigma

        # set timesteps (loop below is adapted from Diffusers)
        scheduler.set_timesteps(steps, DEVICE)
        timesteps = scheduler.timesteps.to(DEVICE)

        num_latent_input = 2
        with torch.no_grad():
            for t in tqdm(timesteps):
                # duplicate the latents for the uncond/cond branches of classifier-free guidance
                latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)

                noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
            latents = latents.to(vae_dtype)
            image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)

        # always cast to float32: negligible overhead and compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        images = [Image.fromarray(im) for im in image]

        # save and finish
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        for i, img in enumerate(images):
            img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))

    if not args.interactive:
        generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
    else:
        # interactive loop; an empty prompt exits
        while True:
            prompt = input("prompt: ")
            if prompt == "":
                break

            prompt2 = input("prompt2: ")
            if prompt2 == "":
                prompt2 = prompt

            negative_prompt = input("negative prompt: ")

            seed_input = input("seed: ")
            if seed_input == "":
                seed = None
            else:
                try:
                    seed = int(seed_input)
                except ValueError:
                    # a typo in the seed should not abort the whole interactive session
                    logger.warning(f"invalid seed {seed_input!r}; using a random seed")
                    seed = None

            generate_image(prompt, prompt2, negative_prompt, seed)

    logger.info("Done!")