From ee4ab23892bbd31836a450ea11a250db492d40f6 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Wed, 16 Oct 2024 11:13:37 +0300
Subject: [PATCH] [SD3 dreambooth-lora training] small updates + bug fixes
 (#9682)

* add latent caching + smol updates

* update license

* replace with free_memory

* add --upcast_before_saving to allow saving transformer weights in lower precision

* fix models to accumulate

* fix mixed precision issue as proposed in https://github.com/huggingface/diffusers/pull/9565

* smol update to readme

* style

* fix caching latents

* style

* add tests for latent caching

* style

* fix latent caching

---------

Co-authored-by: Sayak Paul
---
 examples/dreambooth/README_sd3.md             |  2 +-
 .../dreambooth/test_dreambooth_lora_sd3.py    | 33 +++++++++
 .../dreambooth/train_dreambooth_lora_sd3.py   | 73 ++++++++++++++++---
 examples/dreambooth/train_dreambooth_sd3.py   | 23 +++---
 4 files changed, 107 insertions(+), 24 deletions(-)

diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md
index 6f41c395629a..a340be350db8 100644
--- a/examples/dreambooth/README_sd3.md
+++ b/examples/dreambooth/README_sd3.md
@@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
   --resolution=512 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
-  --learning_rate=1e-5 \
+  --learning_rate=4e-4 \
   --report_to="wandb" \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
index 518738b78246..ec323be4143e 100644
--- a/examples/dreambooth/test_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_sd3(self):
             )
             self.assertTrue(starts_with_expected_prefix)
 
+    def test_dreambooth_lora_latent_caching(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
     def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 02f5a7ee0f7a..8d0b6853eeec 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -140,7 +140,7 @@ def save_model_card(
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
-        license="openrail++",
+        license="other",
         base_model=base_model,
         prompt=instance_prompt,
         model_description=model_description,
@@ -186,7 +186,7 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -608,6 +608,12 @@ def parse_args(input_args=None):
             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
         ),
     )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )
     parser.add_argument(
         "--report_to",
         type=str,
@@ -628,6 +634,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
@@ -1394,6 +1409,16 @@ def load_model_hook(models, input_dir):
             logger.warning(
                 "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
             )
+        if args.train_text_encoder and args.text_encoder_lr:
+            logger.warning(
+                f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using prodigy only learning_rate is used as the initial learning rate."
+            )
+            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # --learning_rate
+            params_to_optimize[1]["lr"] = args.learning_rate
+            params_to_optimize[2]["lr"] = args.learning_rate
 
         optimizer = optimizer_class(
             params_to_optimize,
@@ -1440,6 +1465,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds
 
+    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+    # the redundant encoding.
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
             args.instance_prompt, text_encoders, tokenizers
@@ -1484,6 +1512,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
                 tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
 
+    vae_config_shift_factor = vae.config.shift_factor
+    vae_config_scaling_factor = vae.config.scaling_factor
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+        if args.validation_prompt is None:
+            del vae
+            free_memory()
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1500,7 +1543,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         power=args.lr_power,
     )
 
-    # Prepare everything with our `accelerator`.
     # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
         (
@@ -1607,8 +1649,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
+            if args.train_text_encoder:
+                models_to_accumulate.extend([text_encoder_one, text_encoder_two])
             with accelerator.accumulate(models_to_accumulate):
-                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
@@ -1639,8 +1682,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     )
 
                 # Convert images to latent space
-                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+                if args.cache_latents:
+                    model_input = latents_cache[step].sample()
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
@@ -1773,6 +1821,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                     )
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1793,15 +1843,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-
-                del text_encoder_one, text_encoder_two, text_encoder_three
-                free_memory()
+                if not args.train_text_encoder:
+                    del text_encoder_one, text_encoder_two, text_encoder_three
+                    free_memory()
 
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(torch.float32)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
         if args.train_text_encoder:
diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index c34024f478c1..455ba5a9293d 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -15,7 +15,6 @@
 
 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math
@@ -51,7 +50,7 @@
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import (
     check_min_version,
     is_wandb_available,
@@ -119,7 +118,7 @@ def save_model_card(
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
-        license="openrail++",
+        license="other",
         base_model=base_model,
         prompt=instance_prompt,
         model_description=model_description,
@@ -164,7 +163,7 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -190,8 +189,7 @@ def log_validation(
             )
 
     del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()
 
     return images
 
@@ -1065,8 +1063,7 @@ def main(args):
                     image.save(image_filename)
 
             del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1386,9 +1383,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         del tokenizers, text_encoders
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
         del text_encoder_one, text_encoder_two, text_encoder_three
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1708,6 +1703,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                     )
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
+                    text_encoder_three.to(weight_dtype)
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1730,8 +1728,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two, text_encoder_three
-                    torch.cuda.empty_cache()
-                    gc.collect()
+                    free_memory()
 
     # Save the lora layers
     accelerator.wait_for_everyone()
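
Usage note (a sketch, not part of the patch itself): with the changes above, VAE latent caching and keeping the saved transformer weights in full precision are opt-in via the new `--cache_latents` and `--upcast_before_saving` flags. The command below mirrors the README example for the LoRA script; the model id and the instance/output directories are placeholders to adjust for your setup.

export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sd3-lora"

accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path=${MODEL_NAME} \
  --instance_data_dir=${INSTANCE_DIR} \
  --output_dir=${OUTPUT_DIR} \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=4e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --cache_latents \
  --upcast_before_saving \
  --seed="0"

With `--cache_latents`, all training images are pushed through the VAE once before the training loop and the cached latent distributions are sampled per step (the VAE is also freed when no validation prompt is set). With `--upcast_before_saving`, the trained transformer is cast to float32 before the LoRA state dict is extracted; otherwise the weights are saved in the training precision.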