[SD3 dreambooth-lora training] small updates + bug fixes (huggingface#9682)

* add latent caching + smol updates

* update license

* replace with free_memory

* add --upcast_before_saving to allow saving transformer weights in lower precision

* fix models to accumulate

* fix mixed precision issue as proposed in huggingface#9565

* smol update to readme

* style

* fix caching latents

* style

* add tests for latent caching

* style

* fix latent caching

---------

Co-authored-by: Sayak Paul <[email protected]>
linoytsaban and sayakpaul authored Oct 16, 2024
1 parent cef4f65 commit ee4ab23
Showing 4 changed files with 107 additions and 24 deletions.
examples/dreambooth/README_sd3.md (2 changes: 1 addition & 1 deletion)
@@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
- --learning_rate=1e-5 \
+ --learning_rate=4e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
examples/dreambooth/test_dreambooth_lora_sd3.py (33 changes: 33 additions & 0 deletions)
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_sd3(self):
  )
  self.assertTrue(starts_with_expected_prefix)

+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
  def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
  with tempfile.TemporaryDirectory() as tmpdir:
  test_args = f"""
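To exercise just this new test locally, one minimal sketch (assuming pytest and the example's test dependencies are installed, run from the diffusers repository root; this is not part of the commit itself):

# Minimal sketch: run only the new latent-caching test added in this commit.
# Assumes pytest is installed and the command is run from the diffusers repo root.
import subprocess

subprocess.run(
    [
        "python", "-m", "pytest",
        "examples/dreambooth/test_dreambooth_lora_sd3.py",
        "-k", "latent_caching",
        "-v",
    ],
    check=True,  # raise CalledProcessError if the test fails
)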
examples/dreambooth/train_dreambooth_lora_sd3.py (73 changes: 63 additions & 10 deletions)
@@ -140,7 +140,7 @@ def save_model_card(
  model_card = load_or_create_model_card(
  repo_id_or_path=repo_id,
  from_training=True,
- license="openrail++",
+ license="other",
  base_model=base_model,
  prompt=instance_prompt,
  model_description=model_description,
@@ -186,7 +186,7 @@ def log_validation(
  f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
  f" {args.validation_prompt}."
  )
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
  pipeline.set_progress_bar_config(disable=True)

  # run inference
@@ -608,6 +608,12 @@ def parse_args(input_args=None):
  " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
  ),
  )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
  parser.add_argument(
  "--report_to",
  type=str,
@@ -628,6 +634,15 @@ def parse_args(input_args=None):
  " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
  ),
  )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
  parser.add_argument(
  "--prior_generation_precision",
  type=str,
@@ -1394,6 +1409,16 @@ def load_model_hook(models, input_dir):
  logger.warning(
  "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
  )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate

  optimizer = optimizer_class(
  params_to_optimize,
@@ -1440,6 +1465,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
  pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
  return prompt_embeds, pooled_prompt_embeds

+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
  if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
  instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
  args.instance_prompt, text_encoders, tokenizers
@@ -1484,6 +1512,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
  tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
  tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)

+ vae_config_shift_factor = vae.config.shift_factor
+ vae_config_scaling_factor = vae.config.scaling_factor
+ if args.cache_latents:
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+ if args.validation_prompt is None:
+ del vae
+ free_memory()
+
  # Scheduler and math around the number of training steps.
  overrode_max_train_steps = False
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1500,7 +1543,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
  power=args.lr_power,
  )

- # Prepare everything with our `accelerator`.
  # Prepare everything with our `accelerator`.
  if args.train_text_encoder:
  (
@@ -1607,8 +1649,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

  for step, batch in enumerate(train_dataloader):
  models_to_accumulate = [transformer]
+ if args.train_text_encoder:
+ models_to_accumulate.extend([text_encoder_one, text_encoder_two])
  with accelerator.accumulate(models_to_accumulate):
- pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
  prompts = batch["prompts"]

  # encode batch prompts when custom prompts are provided for each image -
@@ -1639,8 +1682,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
  )

  # Convert images to latent space
- model_input = vae.encode(pixel_values).latent_dist.sample()
- model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
  model_input = model_input.to(dtype=weight_dtype)

  # Sample noise that we'll add to the latents
@@ -1773,6 +1821,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
  text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
  text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
  )
+ text_encoder_one.to(weight_dtype)
+ text_encoder_two.to(weight_dtype)
  pipeline = StableDiffusion3Pipeline.from_pretrained(
  args.pretrained_model_name_or_path,
  vae=vae,
@@ -1793,15 +1843,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
  epoch=epoch,
  torch_dtype=weight_dtype,
  )
-
- del text_encoder_one, text_encoder_two, text_encoder_three
- free_memory()
+ if not args.train_text_encoder:
+ del text_encoder_one, text_encoder_two, text_encoder_three
+ free_memory()

  # Save the lora layers
  accelerator.wait_for_everyone()
  if accelerator.is_main_process:
  transformer = unwrap_model(transformer)
- transformer = transformer.to(torch.float32)
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
  transformer_lora_layers = get_peft_model_state_dict(transformer)

  if args.train_text_encoder:
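For readers skimming the diff, the caching pattern introduced above boils down to: encode every training image once with the VAE under torch.no_grad(), keep the returned latent distributions, and sample from them per step instead of re-running the VAE. A standalone sketch of that idea follows (illustrative names, not the script's exact variables; it assumes a diffusers AutoencoderKL-style vae whose encode() returns an object with a latent_dist):

# Standalone sketch of the latent-caching pattern added in this commit.
# Names are illustrative; the real script also stores vae.config.shift_factor /
# scaling_factor before the VAE is (optionally) deleted, as shown in the diff.
import torch

@torch.no_grad()
def cache_vae_latents(vae, dataloader, device, dtype):
    cache = []
    for batch in dataloader:
        pixels = batch["pixel_values"].to(device, dtype=dtype, non_blocking=True)
        # Keep the distribution so every training step can still draw a fresh sample.
        cache.append(vae.encode(pixels).latent_dist)
    return cache

# Per training step (step indexes the dataloader in the same order it was cached):
# model_input = (cache[step].sample() - shift_factor) * scaling_factor

Dropping the VAE entirely after caching (when no validation prompt is given) is what makes stashing the shift/scaling factors up front necessary.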
examples/dreambooth/train_dreambooth_sd3.py (23 changes: 10 additions & 13 deletions)
@@ -15,7 +15,6 @@

  import argparse
  import copy
- import gc
  import itertools
  import logging
  import math
@@ -51,7 +50,7 @@
  StableDiffusion3Pipeline,
  )
  from diffusers.optimization import get_scheduler
- from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
  from diffusers.utils import (
  check_min_version,
  is_wandb_available,
@@ -119,7 +118,7 @@ def save_model_card(
  model_card = load_or_create_model_card(
  repo_id_or_path=repo_id,
  from_training=True,
- license="openrail++",
+ license="other",
  base_model=base_model,
  prompt=instance_prompt,
  model_description=model_description,
@@ -164,7 +163,7 @@ def log_validation(
  f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
  f" {args.validation_prompt}."
  )
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
  pipeline.set_progress_bar_config(disable=True)

  # run inference
@@ -190,8 +189,7 @@ def log_validation(
  )

  del pipeline
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ free_memory()

  return images

@@ -1065,8 +1063,7 @@ def main(args):
  image.save(image_filename)

  del pipeline
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ free_memory()

  # Handle the repository creation
  if accelerator.is_main_process:
@@ -1386,9 +1383,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
  del tokenizers, text_encoders
  # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
  del text_encoder_one, text_encoder_two, text_encoder_three
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ free_memory()

  # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
  # pack the statically computed variables appropriately here. This is so that we don't
@@ -1708,6 +1703,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
  text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
  text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
  )
+ text_encoder_one.to(weight_dtype)
+ text_encoder_two.to(weight_dtype)
+ text_encoder_three.to(weight_dtype)
  pipeline = StableDiffusion3Pipeline.from_pretrained(
  args.pretrained_model_name_or_path,
  vae=vae,
@@ -1730,8 +1728,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
  )
  if not args.train_text_encoder:
  del text_encoder_one, text_encoder_two, text_encoder_three
- torch.cuda.empty_cache()
- gc.collect()
+ free_memory()

  # Save the lora layers
  accelerator.wait_for_everyone()
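The free_memory() helper used throughout this commit stands in for the manual cleanup visible in the removed lines. A rough sketch of what that manual path did (the actual helper in diffusers.training_utils is the canonical version and may additionally handle non-CUDA backends):

# Rough equivalent of the gc.collect()/torch.cuda.empty_cache() pairs removed in this diff.
import gc
import torch

def free_memory_sketch():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()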
