From 14a1b86fc7de53ff1dbf803f616cbb16ad530e45 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Thu, 19 Sep 2024 21:49:36 -0400 Subject: [PATCH 01/35] Several fixes to Flux ControlNet pipelines (#9472) * fix flux controlnet pipelines --------- Co-authored-by: yiyixuxu --- src/diffusers/pipelines/auto_pipeline.py | 11 ++++++++++- .../flux/pipeline_flux_controlnet.py | 19 +++++++++++-------- ...pipeline_flux_controlnet_image_to_image.py | 4 ++-- .../pipeline_flux_controlnet_inpainting.py | 4 ++-- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 39ceadb5acef..f6186da260ad 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -29,7 +29,14 @@ StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline -from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline +from .flux import ( + FluxControlNetImg2ImgPipeline, + FluxControlNetInpaintPipeline, + FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, + FluxPipeline, +) from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -128,6 +135,7 @@ ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), + ("flux-controlnet", FluxControlNetImg2ImgPipeline), ] ) @@ -143,6 +151,7 @@ ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), + ("flux-controlnet", FluxControlNetInpaintPipeline), ] ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 481994903d3f..11b71b1cbece 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -729,7 +729,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] @@ -763,7 +763,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image_.shape[-2:] @@ -840,12 +840,10 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.tensor([guidance_scale], device=device) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None + guidance = ( + torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None # controlnet controlnet_block_samples, controlnet_single_block_samples = self.controlnet( @@ -863,6 +861,11 @@ def __call__( return_dict=False, ) + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 72803b180c34..deeb9e3f546a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -767,7 +767,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] @@ -798,7 +798,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image_.shape[-2:] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d43acdf38ea5..e763200155f6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -899,7 +899,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] @@ -933,7 +933,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image_.shape[-2:] From e5d0a328d681ea5e7d9ca81f102e757372c0ff10 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Sep 2024 07:10:36 +0530 Subject: [PATCH 02/35] [refactor] LoRA tests (#9481) * refactor scheduler class usage * reorder to make tests more readable * remove pipeline specific checks and skip tests directly * rewrite denoiser conditions cleaner * bump tolerance for cog test --- tests/lora/test_lora_layers_cogvideox.py | 20 +- tests/lora/test_lora_layers_flux.py | 10 +- tests/lora/test_lora_layers_sd3.py | 18 +- tests/lora/utils.py | 383 ++++++----------------- 4 files changed, 142 insertions(+), 289 deletions(-) diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index 17b1cc8e764a..c141ebc96b3e 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -48,6 +48,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = CogVideoXPipeline scheduler_cls = CogVideoXDPMScheduler scheduler_kwargs = {"timestep_spacing": "trailing"} + scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler] transformer_kwargs = { "num_attention_heads": 4, @@ -126,8 +127,7 @@ def get_dummy_inputs(self, with_generator=True): @skip_mps def test_lora_fuse_nan(self): - scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler] - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -156,10 +156,22 @@ def test_lora_fuse_nan(self): self.assertTrue(np.isnan(out).all()) def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3) + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=5e-3) + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + @unittest.skip("Not supported in CogVideoX.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in CogVideoX.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in CogVideoX.") + def test_modify_padding_mode(self): + pass @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_partial_text_lora(self): diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index e849396f7c67..0c336ebc3cbf 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -47,7 +47,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = FluxPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_kwargs = {} - uses_flow_matching = True + scheduler_classes = [FlowMatchEulerDiscreteScheduler] transformer_kwargs = { "patch_size": 1, "in_channels": 4, @@ -154,6 +154,14 @@ def test_with_alpha_in_state_dict(self): ) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) + @unittest.skip("Not supported in Flux.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Flux.") + def test_modify_padding_mode(self): + pass + @slow @require_torch_gpu diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 063ff4c8b05d..8f61c95c2fc8 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -34,7 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = StableDiffusion3Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} - uses_flow_matching = True + scheduler_classes = [FlowMatchEulerDiscreteScheduler] transformer_kwargs = { "sample_size": 32, "patch_size": 1, @@ -92,3 +92,19 @@ def test_sd3_lora(self): lora_filename = "lora_peft_format.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + @unittest.skip("Not supported in SD3.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in SD3.") + def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + pass + + @unittest.skip("Not supported in SD3.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in SD3.") + def test_modify_padding_mode(self): + pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index adf7cb24470f..939b749c286a 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -24,7 +24,6 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, - FlowMatchEulerDiscreteScheduler, LCMScheduler, UNet2DConditionModel, ) @@ -69,9 +68,10 @@ def check_if_lora_correctly_set(model) -> bool: @require_peft_backend class PeftLoraLoaderMixinTests: pipeline_class = None + scheduler_cls = None scheduler_kwargs = None - uses_flow_matching = False + scheduler_classes = [DDIMScheduler, LCMScheduler] has_two_text_encoders = False has_three_text_encoders = False @@ -205,13 +205,7 @@ def test_simple_inference(self): """ Tests a simple inference and makes sure it works as expected """ - # TODO(aryan): Some of the assumptions made here in many different tests are incorrect for CogVideoX. - # For example, we need to test with CogVideoXDDIMScheduler and CogVideoDPMScheduler instead of DDIMScheduler - # and LCMScheduler, which are not supported by it. - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -226,10 +220,7 @@ def test_simple_inference_with_text_lora(self): Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -260,17 +251,16 @@ def test_simple_inference_with_text_lora_and_scale(self): Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() + + # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: if possible_attention_kwargs in call_signature_keys: attention_kwargs_name = possible_attention_kwargs break assert attention_kwargs_name is not None - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -317,10 +307,7 @@ def test_simple_inference_with_text_lora_fused(self): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -360,10 +347,7 @@ def test_simple_inference_with_text_lora_unloaded(self): Tests a simple inference with lora attached to text encoder, then unloads the lora weights and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -410,10 +394,7 @@ def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA. """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -490,10 +471,7 @@ def test_simple_inference_with_partial_text_lora(self): with different ranks and some adapters removed and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, _, _ = self.get_dummy_components(scheduler_cls) # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( @@ -555,10 +533,7 @@ def test_simple_inference_save_pretrained(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -609,10 +584,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -628,13 +600,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config) - else: - pipe.transformer.add_adapter(denoiser_lora_config) - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -652,10 +620,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): else None ) - if self.unet_kwargs is not None: - denoiser_state_dict = get_peft_model_state_dict(pipe.unet) - else: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + denoiser_state_dict = get_peft_model_state_dict(denoiser) saving_kwargs = { "save_directory": tmpdirname, @@ -689,8 +654,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -708,9 +672,6 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): Tests a simple inference with lora attached on the text encoder + Unet + scale argument and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: if possible_attention_kwargs in call_signature_keys: @@ -718,7 +679,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): break assert attention_kwargs_name is not None - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -734,13 +695,9 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config) - else: - pipe.transformer.add_adapter(denoiser_lora_config) - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -781,10 +738,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -800,13 +754,9 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config) - else: - pipe.transformer.add_adapter(denoiser_lora_config) - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -823,8 +773,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -842,10 +791,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -861,12 +807,9 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config) - else: - pipe.transformer.add_adapter(denoiser_lora_config) - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -880,10 +823,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): self.assertFalse( check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" ) - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertFalse( - check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly unloaded in denoiser" - ) + self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -905,10 +845,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused( Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -921,13 +858,9 @@ def test_simple_inference_with_text_denoiser_lora_unfused( check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config) - else: - pipe.transformer.add_adapter(denoiser_lora_config) - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -946,8 +879,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused( if "text_encoder" in self.pipeline_class._lora_loadable_modules: self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -966,10 +898,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -985,17 +914,10 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -1041,15 +963,9 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): def test_simple_inference_with_text_denoiser_block_scale(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches - one adapter and set differnt weights for different blocks (i.e. block lora) + one adapter and set different weights for different blocks (i.e. block lora) """ - if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "CogVideoXPipeline"]: - return - - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1059,14 +975,11 @@ def test_simple_inference_with_text_denoiser_block_scale(self): output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -1109,13 +1022,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set differnt weights for different blocks (i.e. block lora) """ - if self.pipeline_class.__name__ == "StableDiffusion3Pipeline": - return - - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1131,17 +1038,10 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -1193,8 +1093,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" - if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]: - return def updown_options(blocks_with_tf, layers_per_block, value): """ @@ -1266,10 +1164,9 @@ def all_possible_dict_opts(unet, value): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules @@ -1288,10 +1185,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set/delete them """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1307,18 +1201,10 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules @@ -1373,14 +1259,10 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") pipe.set_adapters(["adapter-1", "adapter-2"]) pipe.delete_adapters(["adapter-1", "adapter-2"]) @@ -1397,10 +1279,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1416,17 +1295,10 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules @@ -1471,7 +1343,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): ) pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -1481,10 +1352,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): @skip_mps def test_lora_fuse_nan(self): - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1497,13 +1365,9 @@ def test_lora_fuse_nan(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") # corrupt one LoRA weight with `inf` values with torch.no_grad(): @@ -1520,7 +1384,6 @@ def test_lora_fuse_nan(self): # without we should not see an error, but every image will be black pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) - out = pipe("test", num_inference_steps=2, output_type="np")[0] self.assertTrue(np.isnan(out).all()) @@ -1530,10 +1393,7 @@ def test_get_adapters(self): Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1541,19 +1401,15 @@ def test_get_adapters(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") adapter_names = pipe.get_active_adapters() self.assertListEqual(adapter_names, ["adapter-1"]) pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") adapter_names = pipe.get_active_adapters() self.assertListEqual(adapter_names, ["adapter-2"]) @@ -1566,10 +1422,7 @@ def test_get_list_adapters(self): Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1583,12 +1436,9 @@ def test_get_list_adapters(self): if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - - if self.unet_kwargs is not None: dicts_to_be_checked.update({"unet": ["adapter-1"]}) else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") dicts_to_be_checked.update({"transformer": ["adapter-1"]}) self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) @@ -1601,12 +1451,9 @@ def test_get_list_adapters(self): if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") - - if self.unet_kwargs is not None: dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) @@ -1629,18 +1476,15 @@ def test_get_list_adapters(self): ) # 4. - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-3") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3") - dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config, "adapter-3") dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]}) else: + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]}) self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) @@ -1653,10 +1497,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi( Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet and multi-adapter case """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1672,22 +1513,16 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi( check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" ) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") # Attach a second adapter if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") - else: - pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") + denoiser.add_adapter(denoiser_lora_config, "adapter-2") - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules @@ -1729,10 +1564,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi( @require_peft_version_greater(peft_version="0.9.0") def test_simple_inference_with_dora(self): - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components( scheduler_cls, use_dora=True ) @@ -1745,14 +1577,11 @@ def test_simple_inference_with_dora(self): self.assertTrue(output_no_dora_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config) - else: - pipe.transformer.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules @@ -1775,10 +1604,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1786,14 +1612,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.text_encoder.add_adapter(text_lora_config) - if self.unet_kwargs is not None: - pipe.unet.add_adapter(denoiser_lora_config) - else: - pipe.transformer.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer - self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) @@ -1811,18 +1634,12 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): _ = pipe(**inputs, generator=torch.manual_seed(0))[0] def test_modify_padding_mode(self): - if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]: - return - def set_pad_mode(network, mode="circular"): for _, module in network.named_modules(): if isinstance(module, torch.nn.Conv2d): module.padding_mode = mode - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) - for scheduler_cls in scheduler_classes: + for scheduler_cls in self.scheduler_classes: components, _, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) From aa73072f1f7014635e3de916cbcf47858f4c37a0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 21 Sep 2024 07:44:47 +0530 Subject: [PATCH 03/35] [CI] fix nightly model tests (#9483) * check if default attn procs fix it. * print * print * replace * style./ * replace revision with variant. * replace with stable-diffusion-v1-5/stable-diffusion-inpainting. * replace with stable-diffusion-v1-5/stable-diffusion-v1-5. * fix --- tests/models/autoencoders/test_models_vae.py | 6 +++--- tests/models/unets/test_models_unet_2d_condition.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index 13e9bb1ba569..cdf441c0990d 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -651,12 +651,12 @@ def get_generator(self, seed=0): # fmt: off [ 33, - [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824], + [-0.1556, 0.9848, -0.0410, -0.0642, -0.2685, 0.8381, -0.2004, -0.0700], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824], ], [ 47, - [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089], + [-0.2376, 0.1200, 0.1337, -0.4830, -0.2504, -0.0759, -0.0486, -0.4077], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131], ], # fmt: on @@ -886,7 +886,7 @@ def get_generator(self, seed=0): # fmt: off [ 33, - [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078], + [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205], [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824], ], [ diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index f91686925024..f354950b6075 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1210,11 +1210,11 @@ def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): return image def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): - revision = "fp16" if fp16 else None + variant = "fp16" if fp16 else None torch_dtype = torch.float16 if fp16 else torch.float32 model = UNet2DConditionModel.from_pretrained( - model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision + model_id, subfolder="unet", torch_dtype=torch_dtype, variant=variant ) model.to(torch_device).eval() @@ -1376,7 +1376,7 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): @require_torch_accelerator @skip_mps def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="Jiali/stable-diffusion-1.5") + model = self.get_unet_model(model_id="stable-diffusion-v1-5/stable-diffusion-v1-5") latents = self.get_latents(seed) encoder_hidden_states = self.get_encoder_hidden_states(seed) @@ -1404,7 +1404,7 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): ) @require_torch_accelerator_with_fp16 def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="Jiali/stable-diffusion-1.5", fp16=True) + model = self.get_unet_model(model_id="stable-diffusion-v1-5/stable-diffusion-v1-5", fp16=True) latents = self.get_latents(seed, fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) @@ -1433,7 +1433,7 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): @require_torch_accelerator @skip_mps def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="botp/stable-diffusion-v1-5-inpainting") + model = self.get_unet_model(model_id="stable-diffusion-v1-5/stable-diffusion-inpainting") latents = self.get_latents(seed, shape=(4, 9, 64, 64)) encoder_hidden_states = self.get_encoder_hidden_states(seed) @@ -1461,7 +1461,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): ) @require_torch_accelerator_with_fp16 def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="botp/stable-diffusion-v1-5-inpainting", fp16=True) + model = self.get_unet_model(model_id="stable-diffusion-v1-5/stable-diffusion-inpainting", fp16=True) latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) From ba5af5aebbac0cc18168076a18836f175753d1c7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 23 Sep 2024 11:27:05 +0530 Subject: [PATCH 04/35] [Cog] some minor fixes and nits (#9466) * fix positional arguments in check_inputs(). * add video and latetns to check_inputs(). * prep latents_in_channels. * quality * multiple fixes. * fix --- .../pipelines/cogvideo/pipeline_cogvideox.py | 26 +++++----- .../pipeline_cogvideox_image2video.py | 52 +++++++++---------- .../pipeline_cogvideox_video2video.py | 48 +++++++++-------- 3 files changed, 65 insertions(+), 61 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 02497e77edb7..82839ffd2c92 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -188,6 +188,9 @@ def __init__( self.vae_scale_factor_temporal = ( self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -317,6 +320,12 @@ def encode_prompt( def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + shape = ( batch_size, (num_frames - 1) // self.vae_scale_factor_temporal + 1, @@ -324,11 +333,6 @@ def prepare_latents( height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -341,7 +345,7 @@ def prepare_latents( def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] - latents = 1 / self.vae.config.scaling_factor * latents + latents = 1 / self.vae_scaling_factor_image * latents frames = self.vae.decode(latents).sample return frames @@ -510,10 +514,10 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. num_frames (`int`, defaults to `48`): Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where @@ -587,8 +591,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index a1576be97977..afc11bce00d5 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -207,6 +207,9 @@ def __init__( self.vae_scale_factor_temporal = ( self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -348,6 +351,12 @@ def prepare_latents( generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 shape = ( batch_size, @@ -357,12 +366,6 @@ def prepare_latents( width // self.vae_scale_factor_spatial, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - image = image.unsqueeze(2) # [B, C, F, H, W] if isinstance(generator, list): @@ -373,7 +376,7 @@ def prepare_latents( image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] - image_latents = self.vae.config.scaling_factor * image_latents + image_latents = self.vae_scaling_factor_image * image_latents padding_shape = ( batch_size, @@ -397,7 +400,7 @@ def prepare_latents( # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] - latents = 1 / self.vae.config.scaling_factor * latents + latents = 1 / self.vae_scaling_factor_image * latents frames = self.vae.decode(latents).sample return frames @@ -438,7 +441,6 @@ def check_inputs( width, negative_prompt, callback_on_step_end_tensor_inputs, - video=None, latents=None, prompt_embeds=None, negative_prompt_embeds=None, @@ -494,9 +496,6 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - if video is not None and latents is not None: - raise ValueError("Only one of `video` or `latents` should be provided") - # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections def fuse_qkv_projections(self) -> None: r"""Enables fused QKV projections.""" @@ -584,7 +583,7 @@ def __call__( Args: image (`PipelineImageInput`): - The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -592,10 +591,10 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. num_frames (`int`, defaults to `48`): Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where @@ -665,20 +664,19 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs( - image, - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds, - negative_prompt_embeds, + image=image, + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) self._guidance_scale = guidance_scale self._interrupt = False diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 649199829cf4..35f8f2fa0508 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -204,12 +204,16 @@ def __init__( self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) + self.vae_scale_factor_spatial = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.vae_scale_factor_temporal = ( self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -351,6 +355,12 @@ def prepare_latents( latents: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None, ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) shape = ( @@ -361,12 +371,6 @@ def prepare_latents( width // self.vae_scale_factor_spatial, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if latents is None: if isinstance(generator, list): if len(generator) != batch_size: @@ -382,7 +386,7 @@ def prepare_latents( init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = self.vae_scaling_factor_image * init_latents noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self.scheduler.add_noise(init_latents, noise, timestep) @@ -396,7 +400,7 @@ def prepare_latents( # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] - latents = 1 / self.vae.config.scaling_factor * latents + latents = 1 / self.vae_scaling_factor_image * latents frames = self.vae.decode(latents).sample return frames @@ -589,10 +593,10 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -658,20 +662,20 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs( - prompt, - height, - width, - strength, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds, - negative_prompt_embeds, + prompt=prompt, + height=height, + width=width, + strength=strength, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + video=video, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs From 14f6464bef677e47e1ff13a12f4ddd97e7f3e973 Mon Sep 17 00:00:00 2001 From: M Saqlain <118016760+saqlain2204@users.noreply.github.com> Date: Mon, 23 Sep 2024 20:35:50 +0530 Subject: [PATCH 05/35] [Tests] Reduce the model size in the lumina test (#8985) * Reduced model size for lumina-tests * Handled failing tests --- tests/pipelines/lumina/test_lumina_nextdit.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index d6aeb57b80a1..5fd0dbf06050 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -34,19 +34,19 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM def get_dummy_components(self): torch.manual_seed(0) transformer = LuminaNextDiT2DModel( - sample_size=16, + sample_size=4, patch_size=2, in_channels=4, - hidden_size=24, + hidden_size=4, num_layers=2, - num_attention_heads=3, + num_attention_heads=1, num_kv_heads=1, multiple_of=16, ffn_dim_multiplier=None, norm_eps=1e-5, learn_sigma=True, qk_norm=True, - cross_attention_dim=32, + cross_attention_dim=8, scaling_factor=1.0, ) torch.manual_seed(0) @@ -57,8 +57,8 @@ def get_dummy_components(self): torch.manual_seed(0) config = GemmaConfig( - head_dim=4, - hidden_size=32, + head_dim=2, + hidden_size=8, intermediate_size=37, num_attention_heads=4, num_hidden_layers=2, From 00f5b41862de8d6dfe91688462ccc2023e72c04e Mon Sep 17 00:00:00 2001 From: pibbo88 <81701354+pibbo88@users.noreply.github.com> Date: Tue, 24 Sep 2024 06:30:24 +0800 Subject: [PATCH 06/35] Fix the bug of sd3 controlnet training when using gradient checkpointing. (#9498) Fix the bug of sd3 controlnet training when using gradient_checkpointing. Refer to issue #9496 --- src/diffusers/models/controlnet_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index f19571dafb18..43b52a645a0d 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -336,7 +336,7 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, From 65f9439b569b1e7e2854cfa0274ee3f7f50b43a0 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 24 Sep 2024 00:12:51 +0100 Subject: [PATCH 07/35] [Schedulers] Add exponential sigmas / exponential noise schedule (#9499) * exponential sigmas * Apply suggestions from code review Co-authored-by: YiYi Xu * make style --------- Co-authored-by: YiYi Xu --- .../schedulers/scheduling_euler_discrete.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 46e0e6baef81..2b74558d3cb7 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -158,6 +158,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -186,6 +188,7 @@ def __init__( prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, timestep_spacing: str = "linspace", @@ -235,6 +238,7 @@ def __init__( self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas + self.use_exponential_sigmas = use_exponential_sigmas self._step_index = None self._begin_index = None @@ -332,6 +336,12 @@ def set_timesteps( raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") if timesteps is not None and self.config.use_karras_sigmas: raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") + if timesteps is not None and self.config.use_exponential_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") + if self.config.use_exponential_sigmas and self.config.use_karras_sigmas: + raise ValueError( + "Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`" + ) if ( timesteps is not None and self.config.timestep_type == "continuous" @@ -396,6 +406,10 @@ def set_timesteps( sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 elif self.config.final_sigmas_type == "zero": @@ -468,6 +482,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26 + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps From 3e69e241f794e81d1cfab8ae29695ab384fefaea Mon Sep 17 00:00:00 2001 From: Seongbin Lim <58146755+sbinnee@users.noreply.github.com> Date: Tue, 24 Sep 2024 08:28:14 +0900 Subject: [PATCH 08/35] Allow DDPMPipeline half precision (#9222) Co-authored-by: YiYi Xu --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 093a3cdfe512..bb03a8d66758 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -101,10 +101,10 @@ def __call__( if self.device.type == "mps": # randn does not work reproducibly on mps - image = randn_tensor(image_shape, generator=generator) + image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype) image = image.to(self.device) else: - image = randn_tensor(image_shape, generator=generator, device=self.device) + image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) # set step values self.scheduler.set_timesteps(num_inference_steps) From 19547a57341bc4033b4f372d734c661a69e59311 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 24 Sep 2024 00:39:55 +0100 Subject: [PATCH 09/35] Add Noise Schedule/Schedule Type to Schedulers Overview documentation (#9504) * Add Noise Schedule/Schedule Type to Schedulers Overview docs * Update docs/source/en/api/schedulers/overview.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/schedulers/overview.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/en/api/schedulers/overview.md b/docs/source/en/api/schedulers/overview.md index 28db9f3f7aac..ea8fcb15afd6 100644 --- a/docs/source/en/api/schedulers/overview.md +++ b/docs/source/en/api/schedulers/overview.md @@ -45,6 +45,13 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso | N/A | [`DEISMultistepScheduler`] | | | N/A | [`UniPCMultistepScheduler`] | | +## Noise schedules and schedule types +| A1111/k-diffusion | 🤗 Diffusers | +|---------------------|----------------------------------------| +| Karras | init with `use_karras_sigmas=True` | +| sgm_uniform | init with `timestep_spacing="trailing"`| +| simple | init with `timestep_spacing="trailing"`| + All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers. ## SchedulerMixin From bab17789b5f956a5239df4bb82e0687bf917be31 Mon Sep 17 00:00:00 2001 From: captainzz <73270275+xduzhangjiayu@users.noreply.github.com> Date: Tue, 24 Sep 2024 07:40:44 +0800 Subject: [PATCH 10/35] fix bugs for sd3 controlnet training (#9489) Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- examples/controlnet/train_controlnet_sd3.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 052eb9d4bf76..4b255c501d99 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -31,7 +31,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version @@ -899,12 +899,13 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, + kwargs_handlers=[kwargs], ) # Disable AMP for MPS. From 2b5bc5be0b8939f622461c62dff77c812ef66da4 Mon Sep 17 00:00:00 2001 From: LukeLin <60426396+LukeLIN-web@users.noreply.github.com> Date: Mon, 23 Sep 2024 19:47:34 -0400 Subject: [PATCH 11/35] [Doc] Fix path and and also import imageio (#9506) * Fix bug * import imageio --- docs/source/en/api/pipelines/text_to_video_zero.md | 1 + docs/source/en/optimization/coreml.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/text_to_video_zero.md b/docs/source/en/api/pipelines/text_to_video_zero.md index 6f157c668785..c6bf30fed7af 100644 --- a/docs/source/en/api/pipelines/text_to_video_zero.md +++ b/docs/source/en/api/pipelines/text_to_video_zero.md @@ -40,6 +40,7 @@ To generate a video from prompt, run the following Python code: ```python import torch from diffusers import TextToVideoZeroPipeline +import imageio model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") diff --git a/docs/source/en/optimization/coreml.md b/docs/source/en/optimization/coreml.md index 49ff1e9c9356..d090ef0ed3ba 100644 --- a/docs/source/en/optimization/coreml.md +++ b/docs/source/en/optimization/coreml.md @@ -95,7 +95,7 @@ print(f"Model downloaded at {model_path}") Once you have downloaded a snapshot of the model, you can test it using Apple's Python script. ```shell -python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i models/coreml-stable-diffusion-v1-4_original_packages -o --compute-unit CPU_AND_GPU --seed 93 +python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i ./models/coreml-stable-diffusion-v1-4_original_packages/original/packages -o --compute-unit CPU_AND_GPU --seed 93 ``` Pass the path of the downloaded checkpoint with `-i` flag to the script. `--compute-unit` indicates the hardware you want to allow for inference. It must be one of the following options: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. You may also provide an optional output path, and a seed for reproducibility. From 28f9d84549c0b1d24ef00d69a4c723f3a11cffb6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 24 Sep 2024 09:42:11 +0530 Subject: [PATCH 12/35] [CI] allow faster downloads from the Hub in CI. (#9478) * allow faster downloads from the Hub in CI. * HF_HUB_ENABLE_HF_TRANSFER: 1 * empty * empty * remove ENV HF_HUB_ENABLE_HF_TRANSFER=1. * empty --- .github/workflows/benchmark.yml | 1 + .github/workflows/pr_tests.yml | 1 + .github/workflows/push_tests.yml | 1 + .github/workflows/push_tests_fast.yml | 1 + .github/workflows/push_tests_mps.yml | 1 + docker/diffusers-flax-cpu/Dockerfile | 3 ++- docker/diffusers-flax-tpu/Dockerfile | 3 ++- docker/diffusers-onnxruntime-cpu/Dockerfile | 3 ++- docker/diffusers-onnxruntime-cuda/Dockerfile | 3 ++- docker/diffusers-pytorch-compile-cuda/Dockerfile | 3 ++- docker/diffusers-pytorch-cpu/Dockerfile | 3 ++- docker/diffusers-pytorch-cuda/Dockerfile | 3 ++- docker/diffusers-pytorch-xformers-cuda/Dockerfile | 3 ++- 13 files changed, 21 insertions(+), 8 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 2e7e82f056ed..d311c1c73f11 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -7,6 +7,7 @@ on: env: DIFFUSERS_IS_CI: yes + HF_HUB_ENABLE_HF_TRANSFER: 1 HF_HOME: /mnt/cache OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index d40270ab46fd..6bb67976b170 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -22,6 +22,7 @@ concurrency: env: DIFFUSERS_IS_CI: yes + HF_HUB_ENABLE_HF_TRANSFER: 1 OMP_NUM_THREADS: 4 MKL_NUM_THREADS: 4 PYTEST_TIMEOUT: 60 diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index b8214da328ff..f07e6cda0d59 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -14,6 +14,7 @@ env: DIFFUSERS_IS_CI: yes OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 + HF_HUB_ENABLE_HF_TRANSFER: 1 PYTEST_TIMEOUT: 600 PIPELINE_USAGE_CUTOFF: 50000 diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 4e3a01fdd97f..e8a73446de73 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -18,6 +18,7 @@ env: HF_HOME: /mnt/cache OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 + HF_HUB_ENABLE_HF_TRANSFER: 1 PYTEST_TIMEOUT: 600 RUN_SLOW: no diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index f261b6c00d1c..8d521074a08f 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -13,6 +13,7 @@ env: HF_HOME: /mnt/cache OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 + HF_HUB_ENABLE_HF_TRANSFER: 1 PYTEST_TIMEOUT: 600 RUN_SLOW: no diff --git a/docker/diffusers-flax-cpu/Dockerfile b/docker/diffusers-flax-cpu/Dockerfile index 86a49171d290..051008aa9a2e 100644 --- a/docker/diffusers-flax-cpu/Dockerfile +++ b/docker/diffusers-flax-cpu/Dockerfile @@ -43,6 +43,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ numpy==1.26.4 \ scipy \ tensorboard \ - transformers + transformers \ + hf_transfer CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-flax-tpu/Dockerfile b/docker/diffusers-flax-tpu/Dockerfile index b40cd55a1c16..405f068923b7 100644 --- a/docker/diffusers-flax-tpu/Dockerfile +++ b/docker/diffusers-flax-tpu/Dockerfile @@ -45,6 +45,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ numpy==1.26.4 \ scipy \ tensorboard \ - transformers + transformers \ + hf_transfer CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-onnxruntime-cpu/Dockerfile b/docker/diffusers-onnxruntime-cpu/Dockerfile index a5a6e98605cb..6f4b13e8a9ba 100644 --- a/docker/diffusers-onnxruntime-cpu/Dockerfile +++ b/docker/diffusers-onnxruntime-cpu/Dockerfile @@ -43,6 +43,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ numpy==1.26.4 \ scipy \ tensorboard \ - transformers + transformers \ + hf_transfer CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-onnxruntime-cuda/Dockerfile b/docker/diffusers-onnxruntime-cuda/Dockerfile index 3364698fe945..bd1d871033c9 100644 --- a/docker/diffusers-onnxruntime-cuda/Dockerfile +++ b/docker/diffusers-onnxruntime-cuda/Dockerfile @@ -44,6 +44,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ numpy==1.26.4 \ scipy \ tensorboard \ - transformers + transformers \ + hf_transfer CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-pytorch-compile-cuda/Dockerfile b/docker/diffusers-pytorch-compile-cuda/Dockerfile index a5454328b851..cb4a9c0f9896 100644 --- a/docker/diffusers-pytorch-compile-cuda/Dockerfile +++ b/docker/diffusers-pytorch-compile-cuda/Dockerfile @@ -44,6 +44,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ numpy==1.26.4 \ scipy \ tensorboard \ - transformers + transformers \ + hf_transfer CMD ["/bin/bash"] diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile index 910765bb0b9c..8d98c52598d2 100644 --- a/docker/diffusers-pytorch-cpu/Dockerfile +++ b/docker/diffusers-pytorch-cpu/Dockerfile @@ -44,6 +44,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ numpy==1.26.4 \ scipy \ tensorboard \ - transformers matplotlib + transformers matplotlib \ + hf_transfer CMD ["/bin/bash"] diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index 8b5439ffb6c6..695f5ed08dc5 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -45,6 +45,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ scipy \ tensorboard \ transformers \ - pytorch-lightning + pytorch-lightning \ + hf_transfer CMD ["/bin/bash"] diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile index 7a3408c48624..1693eb293024 100644 --- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile +++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile @@ -45,6 +45,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ scipy \ tensorboard \ transformers \ - xformers + xformers \ + hf_transfer CMD ["/bin/bash"] From bac8a2412d4cb168116fba2fd8143f5dad44c832 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 24 Sep 2024 13:36:53 -1000 Subject: [PATCH 13/35] a few fix for SingleFile tests (#9522) * update sd15 repo * update more --- src/diffusers/loaders/single_file_utils.py | 6 ++- tests/lora/test_lora_layers_sd.py | 54 ++++++++++--------- tests/models/autoencoders/test_models_vae.py | 8 +-- tests/pipelines/controlnet/test_controlnet.py | 30 ++++++----- .../controlnet/test_controlnet_img2img.py | 2 +- .../controlnet/test_controlnet_inpaint.py | 2 +- .../controlnet/test_flax_controlnet.py | 4 +- .../test_ip_adapter_stable_diffusion.py | 37 ++++++++++--- .../test_ledits_pp_stable_diffusion.py | 2 +- tests/pipelines/pag/test_pag_sd.py | 2 +- .../test_semantic_diffusion.py | 10 ++-- .../test_onnx_stable_diffusion.py | 12 ++--- .../test_onnx_stable_diffusion_img2img.py | 4 +- .../stable_diffusion/test_stable_diffusion.py | 32 ++++++----- .../test_stable_diffusion_img2img.py | 12 ++--- .../test_stable_diffusion_inpaint.py | 8 ++- .../test_stable_diffusion_adapter.py | 8 +-- .../test_safe_diffusion.py | 10 ++-- tests/pipelines/test_pipelines_auto.py | 4 +- .../test_text_to_video_zero.py | 2 +- ...iffusion_controlnet_img2img_single_file.py | 6 ++- ...iffusion_controlnet_inpaint_single_file.py | 2 +- ...stable_diffusion_controlnet_single_file.py | 6 ++- ...st_stable_diffusion_img2img_single_file.py | 6 ++- ...st_stable_diffusion_inpaint_single_file.py | 2 +- .../test_stable_diffusion_single_file.py | 6 ++- 26 files changed, 165 insertions(+), 112 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d620c15e8377..236fbd0c2295 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -14,6 +14,7 @@ # limitations under the License. """Conversion script for the Stable Diffusion checkpoints.""" +import copy import os import re from contextlib import nullcontext @@ -91,11 +92,11 @@ "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"}, "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"}, "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"}, - "inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"}, + "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"}, "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, - "v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"}, + "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"}, "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, "stable_cascade_stage_b_lite": { "pretrained_model_name_or_path": "stabilityai/stable-cascade", @@ -541,6 +542,7 @@ def infer_diffusers_model_type(checkpoint): def fetch_diffusers_config(checkpoint): model_type = infer_diffusers_model_type(checkpoint) model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type] + model_path = copy.deepcopy(model_path) return model_path diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 5d79bb0c50bc..50187e50a912 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -102,7 +102,7 @@ def tearDown(self): @slow @require_torch_gpu def test_integration_move_lora_cpu(self): - path = "Jiali/stable-diffusion-1.5" + path = "stable-diffusion-v1-5/stable-diffusion-v1-5" lora_id = "takuma104/lora-test-text-encoder-lora-target" pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) @@ -161,7 +161,7 @@ def test_integration_move_lora_cpu(self): def test_integration_move_lora_dora_cpu(self): from peft import LoraConfig - path = "Jiali/stable-diffusion-1.5" + path = "stable-diffusion-v1-5/stable-diffusion-v1-5" unet_lora_config = LoraConfig( init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], @@ -221,7 +221,7 @@ def tearDown(self): torch.cuda.empty_cache() def test_integration_logits_with_scale(self): - path = "Jiali/stable-diffusion-1.5" + path = "stable-diffusion-v1-5/stable-diffusion-v1-5" lora_id = "takuma104/lora-test-text-encoder-lora-target" pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32) @@ -253,7 +253,7 @@ def test_integration_logits_with_scale(self): release_memory(pipe) def test_integration_logits_no_scale(self): - path = "Jiali/stable-diffusion-1.5" + path = "stable-diffusion-v1-5/stable-diffusion-v1-5" lora_id = "takuma104/lora-test-text-encoder-lora-target" pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32) @@ -284,7 +284,7 @@ def test_dreambooth_old_format(self): lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example" - base_model_id = "Jiali/stable-diffusion-1.5" + base_model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) pipe = pipe.to(torch_device) @@ -308,7 +308,7 @@ def test_dreambooth_text_encoder_new_format(self): lora_model_id = "hf-internal-testing/lora-trained" - base_model_id = "Jiali/stable-diffusion-1.5" + base_model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) pipe = pipe.to(torch_device) @@ -419,9 +419,9 @@ def test_a1111_with_sequential_cpu_offload(self): def test_kohya_sd_v15_with_higher_dimensions(self): generator = torch.Generator().manual_seed(0) - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", safety_checker=None).to( - torch_device - ) + pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None + ).to(torch_device) lora_model_id = "hf-internal-testing/urushisato-lora" lora_filename = "urushisato_v15.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) @@ -444,7 +444,7 @@ def test_vanilla_funetuning(self): lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4" - base_model_id = "Jiali/stable-diffusion-1.5" + base_model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) pipe = pipe.to(torch_device) @@ -467,9 +467,9 @@ def test_unload_kohya_lora(self): prompt = "masterpiece, best quality, mountain" num_inference_steps = 2 - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", safety_checker=None).to( - torch_device - ) + pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None + ).to(torch_device) initial_images = pipe( prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps ).images @@ -505,9 +505,9 @@ def test_load_unload_load_kohya_lora(self): prompt = "masterpiece, best quality, mountain" num_inference_steps = 2 - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", safety_checker=None).to( - torch_device - ) + pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None + ).to(torch_device) initial_images = pipe( prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps ).images @@ -547,9 +547,9 @@ def test_load_unload_load_kohya_lora(self): def test_not_empty_state_dict(self): # Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again - pipe = AutoPipelineForText2Image.from_pretrained("Jiali/stable-diffusion-1.5", torch_dtype=torch.float16).to( - torch_device - ) + pipe = AutoPipelineForText2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ).to(torch_device) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors") @@ -561,9 +561,9 @@ def test_not_empty_state_dict(self): def test_load_unload_load_state_dict(self): # Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again - pipe = AutoPipelineForText2Image.from_pretrained("Jiali/stable-diffusion-1.5", torch_dtype=torch.float16).to( - torch_device - ) + pipe = AutoPipelineForText2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ).to(torch_device) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors") @@ -580,7 +580,9 @@ def test_load_unload_load_state_dict(self): release_memory(pipe) def test_sdv1_5_lcm_lora(self): - pipe = DiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", torch_dtype=torch.float16) + pipe = DiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ) pipe.to(torch_device) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) @@ -608,7 +610,9 @@ def test_sdv1_5_lcm_lora(self): release_memory(pipe) def test_sdv1_5_lcm_lora_img2img(self): - pipe = AutoPipelineForImage2Image.from_pretrained("Jiali/stable-diffusion-1.5", torch_dtype=torch.float16) + pipe = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ) pipe.to(torch_device) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) @@ -649,7 +653,7 @@ def test_sd_load_civitai_empty_network_alpha(self): This test simply checks that loading a LoRA with an empty network alpha works fine See: https://github.com/huggingface/diffusers/issues/5606 """ - pipeline = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") pipeline.enable_sequential_cpu_offload() civitai_path = hf_hub_download("ybelkada/test-ahi-civitai", "ahi_lora_weights.safetensors") pipeline.load_lora_weights(civitai_path, adapter_name="ahri") diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index cdf441c0990d..0188f9121ae0 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -1051,7 +1051,9 @@ def test_encode_decode(self): def test_sd(self): vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", vae=vae, safety_checker=None) + pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None + ) pipe.to(torch_device) out = pipe( @@ -1099,7 +1101,7 @@ def test_sd_f16(self): "openai/consistency-decoder", torch_dtype=torch.float16 ) # TODO - update pipe = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None, @@ -1124,7 +1126,7 @@ def test_sd_f16(self): def test_vae_tiling(self): vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16) pipe = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", vae=vae, safety_checker=None, torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16 ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 302712dbfd0d..b12655d989d4 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -73,7 +73,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.to("cuda") pipe.set_progress_bar_config(disable=None) @@ -715,7 +715,7 @@ def test_canny(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -742,7 +742,7 @@ def test_depth(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -769,7 +769,7 @@ def test_hed(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-hed") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -796,7 +796,7 @@ def test_mlsd(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -823,7 +823,7 @@ def test_normal(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-normal") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -850,7 +850,7 @@ def test_openpose(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -877,7 +877,7 @@ def test_scribble(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-scribble") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -904,7 +904,7 @@ def test_seg(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -935,7 +935,7 @@ def test_sequential_cpu_offloading(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.set_progress_bar_config(disable=None) pipe.enable_attention_slicing() @@ -961,7 +961,7 @@ def test_canny_guess_mode(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -993,7 +993,7 @@ def test_canny_guess_mode_euler(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() @@ -1035,7 +1035,7 @@ def test_v11_shuffle_global_pool_conditions(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -1081,7 +1081,9 @@ def test_pose_and_canny(self): controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose") pipe = StableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=[controlnet_pose, controlnet_canny] + "stable-diffusion-v1-5/stable-diffusion-v1-5", + safety_checker=None, + controlnet=[controlnet_pose, controlnet_canny], ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 2b22b3e5a76d..7c4ae716b37d 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -407,7 +407,7 @@ def test_canny(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index eba493c20588..e49106334c2e 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -504,7 +504,7 @@ def test_inpaint(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint") pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, controlnet=controlnet + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() diff --git a/tests/pipelines/controlnet/test_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py index 6ab66ccb6139..bf5564e810ef 100644 --- a/tests/pipelines/controlnet/test_flax_controlnet.py +++ b/tests/pipelines/controlnet/test_flax_controlnet.py @@ -41,7 +41,7 @@ def test_canny(self): "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16 ) pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 ) params["controlnet"] = controlnet_params @@ -86,7 +86,7 @@ def test_pose(self): "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16 ) pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 ) params["controlnet"] = controlnet_params diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 741639e2d09e..a8180a3bc27f 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -170,7 +170,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): def test_text_to_image(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) pipeline.to(torch_device) pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") @@ -200,7 +203,10 @@ def test_text_to_image(self): def test_image_to_image(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) pipeline.to(torch_device) pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") @@ -232,7 +238,10 @@ def test_image_to_image(self): def test_inpainting(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionInpaintPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) pipeline.to(torch_device) pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") @@ -260,7 +269,10 @@ def test_inpainting(self): def test_text_to_image_model_cpu_offload(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") pipeline.to(torch_device) @@ -287,7 +299,10 @@ def test_text_to_image_model_cpu_offload(self): def test_text_to_image_full_face(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) pipeline.to(torch_device) pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin") @@ -304,7 +319,10 @@ def test_text_to_image_full_face(self): def test_unload(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) before_processors = [attn_proc.__class__ for attn_proc in pipeline.unet.attn_processors.values()] pipeline.to(torch_device) @@ -323,7 +341,10 @@ def test_unload(self): def test_multi(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) pipeline.to(torch_device) pipeline.load_ip_adapter( @@ -343,7 +364,7 @@ def test_multi(self): def test_text_to_image_face_id(self): pipeline = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, torch_dtype=self.dtype + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype ) pipeline.to(torch_device) pipeline.load_ip_adapter( diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py index 12742def67f8..effea2619749 100644 --- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py +++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py @@ -224,7 +224,7 @@ def setUpClass(cls): def test_ledits_pp_editing(self): pipe = LEditsPPPipelineStableDiffusion.from_pretrained( - "Jiali/stable-diffusion-1.5", safety_checker=None, torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16 ) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py index 6a644e02f5e8..3979bb170e0b 100644 --- a/tests/pipelines/pag/test_pag_sd.py +++ b/tests/pipelines/pag/test_pag_sd.py @@ -283,7 +283,7 @@ def test_pag_inference(self): @require_torch_gpu class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase): pipeline_class = StableDiffusionPAGPipeline - repo_id = "Jiali/stable-diffusion-1.5" + repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" def setUp(self): super().setUp() diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py index 3a9d3815e72d..990c389a9c5f 100644 --- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py +++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py @@ -287,7 +287,7 @@ def tearDown(self): def test_positive_guidance(self): torch_device = "cuda" - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -370,7 +370,7 @@ def test_positive_guidance(self): def test_negative_guidance(self): torch_device = "cuda" - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -453,7 +453,7 @@ def test_negative_guidance(self): def test_multi_cond_guidance(self): torch_device = "cuda" - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -536,7 +536,9 @@ def test_multi_cond_guidance(self): def test_guidance_fp16(self): torch_device = "cuda" - pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", torch_dtype=torch.float16) + pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py index 24ea6e07280e..f7036dee47f0 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py @@ -250,10 +250,10 @@ def test_inference_default_pndm(self): def test_inference_ddim(self): ddim_scheduler = DDIMScheduler.from_pretrained( - "Jiali/stable-diffusion-1.5", subfolder="scheduler", revision="onnx" + "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", + "stable-diffusion-v1-5/stable-diffusion-v1-5", revision="onnx", scheduler=ddim_scheduler, safety_checker=None, @@ -276,10 +276,10 @@ def test_inference_ddim(self): def test_inference_k_lms(self): lms_scheduler = LMSDiscreteScheduler.from_pretrained( - "Jiali/stable-diffusion-1.5", subfolder="scheduler", revision="onnx" + "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", + "stable-diffusion-v1-5/stable-diffusion-v1-5", revision="onnx", scheduler=lms_scheduler, safety_checker=None, @@ -327,7 +327,7 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: test_callback_fn.has_been_called = False pipe = OnnxStableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", + "stable-diffusion-v1-5/stable-diffusion-v1-5", revision="onnx", safety_checker=None, feature_extractor=None, @@ -352,7 +352,7 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: def test_stable_diffusion_no_safety_checker(self): pipe = OnnxStableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", + "stable-diffusion-v1-5/stable-diffusion-v1-5", revision="onnx", safety_checker=None, feature_extractor=None, diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py index 086354a2a649..c73ed0f6afe8 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py @@ -210,10 +210,10 @@ def test_inference_k_lms(self): ) init_image = init_image.resize((768, 512)) lms_scheduler = LMSDiscreteScheduler.from_pretrained( - "Jiali/stable-diffusion-1.5", subfolder="scheduler", revision="onnx" + "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", + "stable-diffusion-v1-5/stable-diffusion-v1-5", revision="onnx", scheduler=lms_scheduler, safety_checker=None, diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 5bb13fac9b78..f37d598c8387 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -1332,7 +1332,7 @@ def tearDown(self): def test_download_from_hub(self): ckpt_paths = [ - "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors", + "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors", ] @@ -1346,8 +1346,10 @@ def test_download_from_hub(self): assert image_out.shape == (512, 512, 3) def test_download_local(self): - ckpt_filename = hf_hub_download("Jiali/stable-diffusion-1.5", filename="v1-5-pruned-emaonly.safetensors") - config_filename = hf_hub_download("Jiali/stable-diffusion-1.5", filename="v1-inference.yaml") + ckpt_filename = hf_hub_download( + "stable-diffusion-v1-5/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.safetensors" + ) + config_filename = hf_hub_download("stable-diffusion-v1-5/stable-diffusion-v1-5", filename="v1-inference.yaml") pipe = StableDiffusionPipeline.from_single_file( ckpt_filename, config_files={"v1": config_filename}, torch_dtype=torch.float16 @@ -1402,7 +1404,9 @@ def test_stable_diffusion_1_4_pndm(self): assert max_diff < 1e-3 def test_stable_diffusion_1_5_pndm(self): - sd_pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5").to(torch_device) + sd_pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5").to( + torch_device + ) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_inputs(torch_device) @@ -1483,9 +1487,9 @@ def get_inputs(self, generator_device="cpu", seed=0): return inputs def get_pipeline_output_without_device_map(self): - sd_pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", torch_dtype=torch.float16).to( - torch_device - ) + sd_pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ).to(torch_device) sd_pipe.set_progress_bar_config(disable=True) inputs = self.get_inputs() no_device_map_image = sd_pipe(**inputs).images @@ -1498,7 +1502,7 @@ def test_forward_pass_balanced_device_map(self): no_device_map_image = self.get_pipeline_output_without_device_map() sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", device_map="balanced", torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16 ) sd_pipe_with_device_map.set_progress_bar_config(disable=True) inputs = self.get_inputs() @@ -1509,7 +1513,7 @@ def test_forward_pass_balanced_device_map(self): def test_components_put_in_right_devices(self): sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", device_map="balanced", torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16 ) assert len(set(sd_pipe_with_device_map.hf_device_map.values())) >= 2 @@ -1518,7 +1522,7 @@ def test_max_memory(self): no_device_map_image = self.get_pipeline_output_without_device_map() sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", + "stable-diffusion-v1-5/stable-diffusion-v1-5", device_map="balanced", max_memory={0: "1GB", 1: "1GB"}, torch_dtype=torch.float16, @@ -1532,7 +1536,7 @@ def test_max_memory(self): def test_reset_device_map(self): sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", device_map="balanced", torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16 ) sd_pipe_with_device_map.reset_device_map() @@ -1544,7 +1548,7 @@ def test_reset_device_map(self): def test_reset_device_map_to(self): sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", device_map="balanced", torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16 ) sd_pipe_with_device_map.reset_device_map() @@ -1556,7 +1560,7 @@ def test_reset_device_map_to(self): def test_reset_device_map_enable_model_cpu_offload(self): sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", device_map="balanced", torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16 ) sd_pipe_with_device_map.reset_device_map() @@ -1568,7 +1572,7 @@ def test_reset_device_map_enable_model_cpu_offload(self): def test_reset_device_map_enable_sequential_cpu_offload(self): sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained( - "Jiali/stable-diffusion-1.5", device_map="balanced", torch_dtype=torch.float16 + "stable-diffusion-v1-5/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16 ) sd_pipe_with_device_map.reset_device_map() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index e6de84781f26..7ba0bb5a4a5d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -566,7 +566,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): assert module.device == torch.device("cpu") def test_img2img_2nd_order(self): - sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") sd_pipe.scheduler = HeunDiscreteScheduler.from_config(sd_pipe.scheduler.config) sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -630,7 +630,7 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 def test_img2img_safety_checker_works(self): - sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -686,7 +686,7 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0 return inputs def test_img2img_pndm(self): - sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -701,7 +701,7 @@ def test_img2img_pndm(self): assert max_diff < 1e-3 def test_img2img_ddim(self): - sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config) sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -717,7 +717,7 @@ def test_img2img_ddim(self): assert max_diff < 1e-3 def test_img2img_lms(self): - sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config) sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -733,7 +733,7 @@ def test_img2img_lms(self): assert max_diff < 1e-3 def test_img2img_dpm(self): - sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 28da97be9362..ff04ea2cfc5d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -767,7 +767,9 @@ def test_stable_diffusion_inpaint_strength_test(self): assert np.abs(expected_slice - image_slice).max() < 1e-3 def test_stable_diffusion_simple_inpaint_ddim(self): - pipe = StableDiffusionInpaintPipeline.from_pretrained("Jiali/stable-diffusion-1.5", safety_checker=None) + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None + ) pipe.unet.set_default_attn_processor() pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -973,7 +975,9 @@ def test_stable_diffusion_inpaint_strength_test(self): def test_stable_diffusion_simple_inpaint_ddim(self): vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5") - pipe = StableDiffusionInpaintPipeline.from_pretrained("Jiali/stable-diffusion-1.5", safety_checker=None) + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None + ) pipe.vae = vae pipe.unet.set_default_attn_processor() pipe.to(torch_device) diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py index 8e0841f064e0..2a1e691e9e8f 100644 --- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py +++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py @@ -609,7 +609,7 @@ def tearDown(self): def test_stable_diffusion_adapter_depth_sd_v15(self): adapter_model = "TencentARC/t2iadapter_depth_sd15v2" - sd_model = "Jiali/stable-diffusion-1.5" + sd_model = "stable-diffusion-v1-5/stable-diffusion-v1-5" prompt = "desk" image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png" input_channels = 3 @@ -636,7 +636,7 @@ def test_stable_diffusion_adapter_depth_sd_v15(self): def test_stable_diffusion_adapter_zoedepth_sd_v15(self): adapter_model = "TencentARC/t2iadapter_zoedepth_sd15v1" - sd_model = "Jiali/stable-diffusion-1.5" + sd_model = "stable-diffusion-v1-5/stable-diffusion-v1-5" prompt = "motorcycle" image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motorcycle.png" input_channels = 3 @@ -660,7 +660,7 @@ def test_stable_diffusion_adapter_zoedepth_sd_v15(self): def test_stable_diffusion_adapter_canny_sd_v15(self): adapter_model = "TencentARC/t2iadapter_canny_sd15v2" - sd_model = "Jiali/stable-diffusion-1.5" + sd_model = "stable-diffusion-v1-5/stable-diffusion-v1-5" prompt = "toy" image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png" input_channels = 1 @@ -688,7 +688,7 @@ def test_stable_diffusion_adapter_canny_sd_v15(self): def test_stable_diffusion_adapter_sketch_sd15(self): adapter_model = "TencentARC/t2iadapter_sketch_sd15v2" - sd_model = "Jiali/stable-diffusion-1.5" + sd_model = "stable-diffusion-v1-5/stable-diffusion-v1-5" prompt = "cat" image_url = ( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png" diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py index cfaaa0914d17..ccb20a1c218e 100644 --- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -277,7 +277,9 @@ def tearDown(self): torch.cuda.empty_cache() def test_harm_safe_stable_diffusion(self): - sd_pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", safety_checker=None) + sd_pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None + ) sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -338,7 +340,9 @@ def test_harm_safe_stable_diffusion(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_nudity_safe_stable_diffusion(self): - sd_pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5", safety_checker=None) + sd_pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None + ) sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -392,7 +396,7 @@ def test_nudity_safe_stable_diffusion(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_nudity_safetychecker_safe_stable_diffusion(self): - sd_pipe = StableDiffusionPipeline.from_pretrained("Jiali/stable-diffusion-1.5") + sd_pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py index b899cf240d52..561a9011c6ae 100644 --- a/tests/pipelines/test_pipelines_auto.py +++ b/tests/pipelines/test_pipelines_auto.py @@ -40,7 +40,7 @@ PRETRAINED_MODEL_REPO_MAPPING = OrderedDict( [ - ("stable-diffusion", "Jiali/stable-diffusion-1.5"), + ("stable-diffusion", "stable-diffusion-v1-5/stable-diffusion-v1-5"), ("if", "DeepFloyd/IF-I-XL-v1.0"), ("kandinsky", "kandinsky-community/kandinsky-2-1"), ("kandinsky22", "kandinsky-community/kandinsky-2-2-decoder"), @@ -539,7 +539,7 @@ def test_from_pipe_consistent(self): def test_controlnet(self): # test from_pretrained - model_repo = "Jiali/stable-diffusion-1.5" + model_repo = "stable-diffusion-v1-5/stable-diffusion-v1-5" controlnet_repo = "lllyasviel/sd-controlnet-canny" controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=torch.float16) diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py index 9038e3b0100f..f1bf6ee52206 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py @@ -40,7 +40,7 @@ def tearDown(self): torch.cuda.empty_cache() def test_full_model(self): - model_id = "Jiali/stable-diffusion-1.5" + model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) generator = torch.Generator(device="cuda").manual_seed(0) diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py index 5127b9e745d8..332bcfbe03b6 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py @@ -30,11 +30,13 @@ @require_torch_gpu class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetPipeline - ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = ( + "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ) original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Jiali/stable-diffusion-1.5" + repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" def setUp(self): super().setUp() diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py index 9d6576078a9c..c0d70123b286 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py @@ -31,7 +31,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC pipeline_class = StableDiffusionControlNetInpaintPipeline ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt" original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml" - repo_id = "botp/stable-diffusion-v1-5-inpainting" + repo_id = "stable-diffusion-v1-5/stable-diffusion-inpainting" def setUp(self): super().setUp() diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py index 13d64dab77a1..3b5cf910b080 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py @@ -29,11 +29,13 @@ @require_torch_gpu class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetPipeline - ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = ( + "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ) original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Jiali/stable-diffusion-1.5" + repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" def setUp(self): super().setUp() diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py index fd99c4dede2f..04f36f255014 100644 --- a/tests/single_file/test_stable_diffusion_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py @@ -23,11 +23,13 @@ @require_torch_gpu class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionImg2ImgPipeline - ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = ( + "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ) original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Jiali/stable-diffusion-1.5" + repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" def setUp(self): super().setUp() diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py index ba9583639b98..5c6734a9a33e 100644 --- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py @@ -63,7 +63,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self): def test_single_file_loading_4_channel_unet(self): # Test loading single file inpaint with a 4 channel UNet - ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" pipe = self.pipeline_class.from_single_file(ckpt_path) assert pipe.unet.config.in_channels == 4 diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py index d40af28b2407..e46e87e18c18 100644 --- a/tests/single_file/test_stable_diffusion_single_file.py +++ b/tests/single_file/test_stable_diffusion_single_file.py @@ -26,11 +26,13 @@ @require_torch_gpu class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionPipeline - ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = ( + "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ) original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Jiali/stable-diffusion-1.5" + repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" def setUp(self): super().setUp() From b52684c3edfa4aa33d72add90385c7b76c968b24 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 25 Sep 2024 01:50:12 +0100 Subject: [PATCH 14/35] Add exponential sigmas to other schedulers and update docs (#9518) --- docs/source/en/api/schedulers/overview.md | 11 ++++--- .../schedulers/scheduling_deis_multistep.py | 30 +++++++++++++++++ .../scheduling_dpmsolver_multistep.py | 32 +++++++++++++++++++ .../scheduling_dpmsolver_multistep_inverse.py | 31 ++++++++++++++++++ .../schedulers/scheduling_dpmsolver_sde.py | 30 +++++++++++++++++ .../scheduling_dpmsolver_singlestep.py | 32 +++++++++++++++++++ .../schedulers/scheduling_euler_discrete.py | 6 ++-- .../schedulers/scheduling_heun_discrete.py | 32 +++++++++++++++++++ .../scheduling_k_dpm_2_ancestral_discrete.py | 30 +++++++++++++++++ .../schedulers/scheduling_k_dpm_2_discrete.py | 30 +++++++++++++++++ .../schedulers/scheduling_lms_discrete.py | 30 +++++++++++++++++ .../schedulers/scheduling_sasolver.py | 30 +++++++++++++++++ .../schedulers/scheduling_unipc_multistep.py | 30 +++++++++++++++++ 13 files changed, 345 insertions(+), 9 deletions(-) diff --git a/docs/source/en/api/schedulers/overview.md b/docs/source/en/api/schedulers/overview.md index ea8fcb15afd6..2150357cc2b9 100644 --- a/docs/source/en/api/schedulers/overview.md +++ b/docs/source/en/api/schedulers/overview.md @@ -46,11 +46,12 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso | N/A | [`UniPCMultistepScheduler`] | | ## Noise schedules and schedule types -| A1111/k-diffusion | 🤗 Diffusers | -|---------------------|----------------------------------------| -| Karras | init with `use_karras_sigmas=True` | -| sgm_uniform | init with `timestep_spacing="trailing"`| -| simple | init with `timestep_spacing="trailing"`| +| A1111/k-diffusion | 🤗 Diffusers | +|--------------------------|----------------------------------------------------------------------------| +| Karras | init with `use_karras_sigmas=True` | +| sgm_uniform | init with `timestep_spacing="trailing"` | +| simple | init with `timestep_spacing="trailing"` | +| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` | All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers. diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 11073ce491d3..3b26befac64e 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -111,6 +111,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -138,9 +140,12 @@ def __init__( solver_type: str = "logrho", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -255,6 +260,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -366,6 +374,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 4472a06c3428..924eefb0e98d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -161,6 +161,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. use_lu_lambdas (`bool`, *optional*, defaults to `False`): Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of @@ -206,6 +208,7 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, use_lu_lambdas: Optional[bool] = False, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), @@ -214,6 +217,8 @@ def __init__( steps_offset: int = 0, rescale_betas_zero_snr: bool = False, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) @@ -330,6 +335,8 @@ def set_timesteps( raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") if timesteps is not None and self.config.use_lu_lambdas: raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") + if timesteps is not None and self.config.use_exponential_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") if timesteps is not None: timesteps = np.array(timesteps).astype(np.int64) @@ -378,6 +385,9 @@ def set_timesteps( lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) sigmas = np.exp(lambdas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -510,6 +520,28 @@ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return lambdas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 6628a92ba034..4f024b8c4c75 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -124,6 +124,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -158,11 +160,14 @@ def __init__( lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) @@ -213,6 +218,7 @@ def __init__( self._step_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.use_karras_sigmas = use_karras_sigmas + self.use_exponential_sigmas = use_exponential_sigmas @property def step_index(self): @@ -267,6 +273,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_max = ( @@ -385,6 +394,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output def convert_model_output( self, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 7f2dd081577b..3748de63388a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -160,6 +160,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed is generated. timestep_spacing (`str`, defaults to `"linspace"`): @@ -182,10 +184,13 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -341,6 +346,9 @@ def set_timesteps( if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas) @@ -421,6 +429,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + @property def state_in_first_order(self): return self.sample is None diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 1a10fff043fb..353baf08e81d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -123,6 +123,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. final_sigmas_type (`str`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. @@ -154,10 +156,13 @@ def __init__( solver_type: str = "midpoint", lower_order_final: bool = False, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if algorithm_type == "dpmsolver": deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) @@ -300,6 +305,8 @@ def set_timesteps( raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") if timesteps is not None and self.config.use_karras_sigmas: raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.") + if timesteps is not None and self.config.use_exponential_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps @@ -323,6 +330,9 @@ def set_timesteps( sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -452,6 +462,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 2b74558d3cb7..e79dbe3fe8ab 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -197,6 +197,8 @@ def __init__( rescale_betas_zero_snr: bool = False, final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -338,10 +340,6 @@ def set_timesteps( raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") if timesteps is not None and self.config.use_exponential_sigmas: raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") - if self.config.use_exponential_sigmas and self.config.use_karras_sigmas: - raise ValueError( - "Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`" - ) if ( timesteps is not None and self.config.timestep_type == "continuous" diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 8d0a4a830f42..efcfdeb1d5ef 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -97,6 +97,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -117,11 +119,14 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, clip_sample: Optional[bool] = False, clip_sample_range: float = 1.0, timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -251,6 +256,8 @@ def set_timesteps( raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") if timesteps is not None and self.config.use_karras_sigmas: raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + if timesteps is not None and self.config.use_exponential_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps @@ -286,6 +293,9 @@ def set_timesteps( if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) @@ -354,6 +364,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + @property def state_in_first_order(self): return self.dt is None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 338412d96bd5..038aa19603ea 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -91,6 +91,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -114,10 +116,13 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -250,6 +255,9 @@ def set_timesteps( if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) self.log_sigmas = torch.from_numpy(log_sigmas).to(device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -346,6 +354,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + @property def state_in_first_order(self): return self.sample is None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index de66a7b6eaa1..8fbf66832668 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -90,6 +90,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -113,10 +115,13 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -249,6 +254,9 @@ def set_timesteps( if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -359,6 +367,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def step( self, model_output: Union[torch.Tensor, np.ndarray], diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9595bb4c71ba..5ef8ffb0dcbf 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -111,6 +111,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -134,10 +136,13 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -289,6 +294,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -362,6 +370,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def step( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 50049a530800..ad79c69fc714 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -122,6 +122,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -156,11 +158,14 @@ def __init__( algorithm_type: str = "data_prediction", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -284,6 +289,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -395,6 +403,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 995f85c020ed..78cf0b6d16a7 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -159,6 +159,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -195,11 +197,14 @@ def __init__( disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" rescale_betas_zero_snr: bool = False, ): + if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -329,6 +334,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min": @@ -450,6 +458,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() + return sigmas + def convert_model_output( self, model_output: torch.Tensor, From 6ca5a58e43dfb3e922ace5c7f7a54d73280be38b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 25 Sep 2024 15:25:15 +0530 Subject: [PATCH 15/35] [Community Pipeline] Batched implementation of Flux with CFG (#9513) * batched implementation of flux cfg. * style. * readme * remove comments. --- examples/community/README.md | 31 ++++ examples/community/pipeline_flux_with_cfg.py | 148 ++++++++++--------- 2 files changed, 110 insertions(+), 69 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index 8f4ab80d680b..e51124e75956 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | Example | Description | Code Example | Colab | Author | |:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:| +|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)| |Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)| | HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) | | Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) | @@ -82,6 +83,36 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion ## Example usages +### Flux with CFG + +Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md). + +Example usage: + +```py +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, + custom_pipeline="pipeline_flux_with_cfg" +) +pipeline.enable_model_cpu_offload() +prompt = "a watercolor painting of a unicorn" +negative_prompt = "pink" + +img = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + true_cfg=1.5, + guidance_scale=3.5, + num_images_per_prompt=1, + generator=torch.manual_seed(0) +).images[0] +img.save("cfg_flux.png") +``` + ### Differential Diffusion **Eran Levin, Ohad Fried** diff --git a/examples/community/pipeline_flux_with_cfg.py b/examples/community/pipeline_flux_with_cfg.py index 7cfa7b728980..06da6da899cd 100644 --- a/examples/community/pipeline_flux_with_cfg.py +++ b/examples/community/pipeline_flux_with_cfg.py @@ -289,80 +289,104 @@ def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, + do_true_cfg: bool = False, ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ device = device or self._execution_device - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it + # Set LoRA scale if applicable if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale - # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if do_true_cfg and negative_prompt is not None: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_batch_size = len(negative_prompt) + + if negative_batch_size != batch_size: + raise ValueError( + f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})" + ) + + # Concatenate prompts + prompts = prompt + negative_prompt + prompts_2 = ( + prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None + ) + else: + prompts = prompt + prompts_2 = prompt_2 if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + if prompts_2 is None: + prompts_2 = prompts - # We only use the pooled prompt output from the CLIPTextModel + # Get pooled prompt embeddings from CLIPTextModel pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, + prompt=prompts, device=device, num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, + prompt=prompts_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) + if do_true_cfg and negative_prompt is not None: + # Split embeddings back into positive and negative parts + total_batch_size = batch_size * num_images_per_prompt + positive_indices = slice(0, total_batch_size) + negative_indices = slice(total_batch_size, 2 * total_batch_size) + + positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices] + negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices] + + positive_prompt_embeds = prompt_embeds[positive_indices] + negative_prompt_embeds = prompt_embeds[negative_indices] + + pooled_prompt_embeds = positive_pooled_prompt_embeds + prompt_embeds = positive_prompt_embeds + + # Unscale LoRA layers if self.text_encoder is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - return prompt_embeds, pooled_prompt_embeds, text_ids + if do_true_cfg and negative_prompt is not None: + return ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ) + else: + return prompt_embeds, pooled_prompt_embeds, text_ids, None, None def check_inputs( self, @@ -687,38 +711,33 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) + do_true_cfg = true_cfg > 1 and negative_prompt is not None ( prompt_embeds, pooled_prompt_embeds, text_ids, + negative_prompt_embeds, + negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, + do_true_cfg=do_true_cfg, ) - # perform "real" CFG as suggested for distilled Flux models in https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md - do_true_cfg = true_cfg > 1 and negative_prompt is not None if do_true_cfg: - ( - negative_prompt_embeds, - negative_pooled_prompt_embeds, - negative_text_ids, - ) = self.encode_prompt( - prompt=negative_prompt, - prompt_2=negative_prompt_2, - prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=negative_pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) + # Concatenate embeddings + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -754,24 +773,26 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue + latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, @@ -783,18 +804,7 @@ def __call__( )[0] if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - + neg_noise_pred, noise_pred = noise_pred.chunk(2) noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 From 065ce07ac334ebe24e204515d39a388ab560787e Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Thu, 26 Sep 2024 03:54:36 +1200 Subject: [PATCH 16/35] Update community_projects.md (#9266) --- docs/source/en/community_projects.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/community_projects.md b/docs/source/en/community_projects.md index 8a00a7fda7ed..4ab1829871c8 100644 --- a/docs/source/en/community_projects.md +++ b/docs/source/en/community_projects.md @@ -75,4 +75,8 @@ Happy exploring, and thank you for being part of the Diffusers community! StreamDiffusion A Pipeline-Level Solution for Real-Time Interactive Generation + + Stable Diffusion Server + A server configured for Inpainting/Generation/img2img with one stable diffusion model + From d9c969172d97796bf03066b0af72d3a20410bf44 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Wed, 25 Sep 2024 09:33:54 -0700 Subject: [PATCH 17/35] [docs] Model sharding (#9521) * flux shard * feedback --- docs/source/en/_toctree.yml | 2 +- .../en/training/distributed_inference.md | 130 +++++++++++++++++- 2 files changed, 130 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a282ca717a9f..b331e4b13760 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -56,7 +56,7 @@ - local: using-diffusers/overview_techniques title: Overview - local: training/distributed_inference - title: Distributed inference with multiple GPUs + title: Distributed inference - local: using-diffusers/merge_loras title: Merge LoRAs - local: using-diffusers/scheduler_features diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 5c371033dfd5..cd642d6aca07 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Distributed inference with multiple GPUs +# Distributed inference On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel. @@ -109,3 +109,131 @@ torchrun run_distributed.py --nproc_per_node=2 > [!TIP] > You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more. + +## Model sharding + +Modern diffusion systems such as [Flux](../api/pipelines/flux) are very large and have multiple models. For example, [Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) is made up of two text encoders - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) and [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - a [diffusion transformer](../api/models/flux_transformer), and a [VAE](../api/models/autoencoderkl). With a model this size, it can be challenging to run inference on consumer GPUs. + +Model sharding is a technique that distributes models across GPUs when the models don't fit on a single GPU. The example below assumes two 16GB GPUs are available for inference. + +Start by computing the text embeddings with the text encoders. Keep the text encoders on two GPUs by setting `device_map="balanced"`. The `balanced` strategy evenly distributes the model on all available GPUs. Use the `max_memory` parameter to allocate the maximum amount of memory for each text encoder on each GPU. + +> [!TIP] +> **Only** load the text encoders for this step! The diffusion transformer and VAE are loaded in a later step to preserve memory. + +```py +from diffusers import FluxPipeline +import torch + +prompt = "a photo of a dog with cat-like look" + +pipeline = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + transformer=None, + vae=None, + device_map="balanced", + max_memory={0: "16GB", 1: "16GB"}, + torch_dtype=torch.bfloat16 +) +with torch.no_grad(): + print("Encoding prompts.") + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( + prompt=prompt, prompt_2=None, max_sequence_length=512 + ) +``` + +Once the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer. + +```py +import gc + +def flush(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + +del pipeline.text_encoder +del pipeline.text_encoder_2 +del pipeline.tokenizer +del pipeline.tokenizer_2 +del pipeline + +flush() +``` + +Load the diffusion transformer next which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency. + +```py +from diffusers import FluxTransformer2DModel +import torch + +transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + device_map="auto", + torch_dtype=torch.bfloat16 +) +``` + +> [!TIP] +> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. + +Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet. + +```py +pipeline = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", , + text_encoder=None, + text_encoder_2=None, + tokenizer=None, + tokenizer_2=None, + vae=None, + transformer=transformer, + torch_dtype=torch.bfloat16 +) + +print("Running denoising.") +height, width = 768, 1360 +latents = pipeline( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=50, + guidance_scale=3.5, + height=height, + width=width, + output_type="latent", +).images +``` + +Remove the pipeline and transformer from memory as they're no longer needed. + +```py +del pipeline.transformer +del pipeline + +flush() +``` + +Finally, decode the latents with the VAE into an image. The VAE is typically small enough to be loaded on a single GPU. + +```py +from diffusers import AutoencoderKL +from diffusers.image_processor import VaeImageProcessor +import torch + +vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda") +vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) +image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + +with torch.no_grad(): + print("Running decoding.") + latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor) + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor + + image = vae.decode(latents, return_dict=False)[0] + image = image_processor.postprocess(image, output_type="pil") + image[0].save("split_transformer.png") +``` + +By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs. From c76e88405c1464357a8fa508bef37ce561899e2b Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 25 Sep 2024 11:00:57 -1000 Subject: [PATCH 18/35] update get_parameter_dtype (#9526) * up * Update src/diffusers/models/modeling_utils.py Co-authored-by: Aryan --------- Co-authored-by: Aryan --- src/diffusers/models/modeling_utils.py | 30 +++++++++++--------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index cfe692dcc54a..9e0c50e8b37b 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -93,24 +93,20 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: try: - params = tuple(parameter.parameters()) - if len(params) > 0: - return params[0].dtype - - buffers = tuple(parameter.buffers()) - if len(buffers) > 0: - return buffers[0].dtype - + return next(parameter.parameters()).dtype except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype + try: + return next(parameter.buffers()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype class ModelMixin(torch.nn.Module, PushToHubMixin): From aa3c46d99acfaa145bdf620f821de9b409c2e6c6 Mon Sep 17 00:00:00 2001 From: v2ray <60914079+LagPixelLOL@users.noreply.github.com> Date: Thu, 26 Sep 2024 06:26:58 +0800 Subject: [PATCH 19/35] [Doc] Improved level of clarity for latents_to_rgb. (#9529) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixed latents_to_rgb doc. Co-authored-by: Álvaro Somoza --- docs/source/en/using-diffusers/callback.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md index d4d23d62540f..68c621ffc50d 100644 --- a/docs/source/en/using-diffusers/callback.md +++ b/docs/source/en/using-diffusers/callback.md @@ -171,14 +171,13 @@ def latents_to_rgb(latents): weights = ( (60, -60, 25, -70), (60, -5, 15, -50), - (60, 10, -5, -35) + (60, 10, -5, -35), ) weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1) - image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy() - image_array = image_array.transpose(1, 2, 0) + image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) return Image.fromarray(image_array) ``` @@ -189,7 +188,7 @@ def latents_to_rgb(latents): def decode_tensors(pipe, step, timestep, callback_kwargs): latents = callback_kwargs["latents"] - image = latents_to_rgb(latents) + image = latents_to_rgb(latents[0]) image.save(f"{step}.png") return callback_kwargs From 1c6ede9371815dcb27bffb3ec365799ebdd2f04e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 26 Sep 2024 00:30:32 +0100 Subject: [PATCH 20/35] [Schedulers] Add beta sigmas / beta noise schedule (#9509) Add beta sigmas / beta noise schedule --- .../schedulers/scheduling_euler_discrete.py | 55 ++++++++++++++++++- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index e79dbe3fe8ab..5c39583356ad 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -20,11 +20,14 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging +from ..utils import BaseOutput, is_scipy_available, logging from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +if is_scipy_available(): + import scipy.stats + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -160,6 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -189,6 +195,7 @@ def __init__( interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, timestep_spacing: str = "linspace", @@ -197,8 +204,12 @@ def __init__( rescale_betas_zero_snr: bool = False, final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -241,6 +252,7 @@ def __init__( self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas self.use_exponential_sigmas = use_exponential_sigmas + self.use_beta_sigmas = use_beta_sigmas self._step_index = None self._begin_index = None @@ -340,6 +352,8 @@ def set_timesteps( raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") if timesteps is not None and self.config.use_exponential_sigmas: raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") + if timesteps is not None and self.config.use_beta_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.") if ( timesteps is not None and self.config.timestep_type == "continuous" @@ -408,6 +422,10 @@ def set_timesteps( sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 elif self.config.final_sigmas_type == "zero": @@ -502,6 +520,37 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps From 9cd37557d581dd30fb6031ae30bd583443c3effd Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 25 Sep 2024 19:09:54 -1000 Subject: [PATCH 21/35] flux controlnet fix (control_modes batch & others) (#9507) * flux controlnet mode to take into account batch size * incorporate yiyixuxu's suggestions (cleaner logic) as well as clean up control mode handling for multi case * fix * fix use_guidance when controlnet is a multi and does not have config --------- Co-authored-by: Christopher Beckham Co-authored-by: Sayak Paul --- src/diffusers/models/controlnet_flux.py | 23 +++++------ .../flux/pipeline_flux_controlnet.py | 39 ++++++++++++------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 036e5654a98e..88ad49d2b776 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -502,16 +502,17 @@ def forward( control_block_samples = block_samples control_single_block_samples = single_block_samples else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) - ] - - control_single_block_samples = [ - control_single_block_sample + block_sample - for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples - ) - ] + if block_samples is not None and control_block_samples is not None: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + if single_block_samples is not None and control_single_block_samples is not None: + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] return control_block_samples, control_single_block_samples diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 11b71b1cbece..6c072c482020 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -747,10 +747,12 @@ def __call__( width_control_image, ) - # set control mode + # Here we ensure that `control_mode` has the same length as the control_image. if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`") control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] @@ -785,16 +787,22 @@ def __call__( control_image = control_images + # Here we ensure that `control_mode` has the same length as the control_image. + if isinstance(control_mode, list) and len(control_mode) != len(control_image): + raise ValueError( + "For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified" + ) + if not isinstance(control_mode, list): + control_mode = [control_mode] * len(control_image) # set control mode - control_mode_ = [] - if isinstance(control_mode, list): - for cmode in control_mode: - if cmode is None: - control_mode_.append(-1) - else: - control_mode_.append(cmode) - control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_modes = [] + for cmode in control_mode: + if cmode is None: + cmode = -1 + control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long) + control_modes.append(control_mode) + control_mode = control_modes # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -840,9 +848,12 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - guidance = ( - torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None - ) + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + + guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None guidance = guidance.expand(latents.shape[0]) if guidance is not None else None # controlnet From 066ea374c860b74898ba6cc996c33304f3b973de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Thu, 26 Sep 2024 05:10:15 -0300 Subject: [PATCH 22/35] [Tests] Fix ChatGLMTokenizer (#9536) fix --- src/diffusers/pipelines/kolors/tokenizer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/kolors/tokenizer.py b/src/diffusers/pipelines/kolors/tokenizer.py index a7b942f4fd22..fa241b920c97 100644 --- a/src/diffusers/pipelines/kolors/tokenizer.py +++ b/src/diffusers/pipelines/kolors/tokenizer.py @@ -277,6 +277,7 @@ def _pad( padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, return_attention_mask: Optional[bool] = None, + padding_side: Optional[bool] = None, ) -> dict: """ Pad encoded inputs (on left/right and up to predefined length or max length in the batch) @@ -298,6 +299,9 @@ def _pad( pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ From 665c6b47a23bc841ad1440c4fe9cbb1782258656 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 26 Sep 2024 22:12:07 +0530 Subject: [PATCH 23/35] [bug] Precedence of operations in VAE should be slicing -> tiling (#9342) * bugfix: precedence of operations should be slicing -> tiling * fix typo * fix another typo * deprecate current implementation of tiled_encode and use new impl * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu * Update src/diffusers/models/autoencoders/autoencoder_kl.py --------- Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- .../models/autoencoders/autoencoder_kl.py | 82 ++++++++++++++++--- 1 file changed, 71 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 161770c67cf8..99a7da4a0b6f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -18,6 +18,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import deprecate from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -245,6 +246,18 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True @@ -261,21 +274,13 @@ def encode( The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): - return self.tiled_encode(x, return_dict=return_dict) - if self.use_slicing and x.shape[0] > 1: - encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: - h = self.encoder(x) - - if self.quant_conv is not None: - moments = self.quant_conv(h) - else: - moments = h + h = self._encode(x) - posterior = DiagonalGaussianDistribution(moments) + posterior = DiagonalGaussianDistribution(h) if not return_dict: return (posterior,) @@ -337,6 +342,54 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. @@ -356,6 +409,13 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Autoencoder If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." + ) + deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent From 2daedc0ad3c868931689b2cbbe6b0243a2f3166b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 27 Sep 2024 07:32:43 +0530 Subject: [PATCH 24/35] [LoRA] make set_adapters() method more robust. (#9535) * make set_adapters() method more robust. * remove patch * better and concise code. * Update src/diffusers/loaders/lora_base.py Co-authored-by: YiYi Xu --------- Co-authored-by: YiYi Xu --- src/diffusers/loaders/lora_base.py | 14 +++++++--- tests/lora/utils.py | 44 ++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 89bb498a3acd..e124b6eeacf3 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -532,13 +532,19 @@ def set_adapters( ) list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]} - all_adapters = { - adapter for adapters in list_adapters.values() for adapter in adapters - } # eg ["adapter1", "adapter2"] + # eg ["adapter1", "adapter2"] + all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters} + missing_adapters = set(adapter_names) - all_adapters + if len(missing_adapters) > 0: + raise ValueError( + f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}." + ) + + # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} invert_list_adapters = { adapter: [part for part, adapters in list_adapters.items() if adapter in adapters] for adapter in all_adapters - } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} + } # Decompose weights into weights for denoiser and text encoders. _component_adapter_weights = {} diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 939b749c286a..43c45daaa322 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -929,12 +929,24 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter outputs should be different.", + ) # Fuse and unfuse should lead to the same results self.assertFalse( @@ -960,6 +972,38 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): "output with no lora and output with lora disabled should give same results", ) + def test_wrong_adapter_name_raises_error(self): + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + with self.assertRaises(ValueError) as err_context: + pipe.set_adapters("test") + + self.assertTrue("not in the list of present adapters" in str(err_context.exception)) + + # test this works. + pipe.set_adapters("adapter-1") + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + def test_simple_inference_with_text_denoiser_block_scale(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches From 534848c370dafddc3f9fb95151f472da39486206 Mon Sep 17 00:00:00 2001 From: PromeAI Date: Fri, 27 Sep 2024 16:01:47 +0800 Subject: [PATCH 25/35] [examples] add train flux-controlnet scripts in example. (#9324) * add train flux-controlnet scripts in example. * fix error * fix subfolder error * fix preprocess error * Update examples/controlnet/README_flux.md Co-authored-by: Sayak Paul * Update examples/controlnet/README_flux.md Co-authored-by: Sayak Paul * fix readme * fix note error * add some Tutorial for deepspeed * fix some Format Error * add dataset_path example * remove print, add guidance_scale CLI, readable apply * Update examples/controlnet/README_flux.md Co-authored-by: Sayak Paul * update,push_to_hub,save_weight_dtype,static method,clear_objs_and_retain_memory,report_to=wandb * add push to hub in readme * apply weighting schemes * add note * Update examples/controlnet/README_flux.md Co-authored-by: Sayak Paul * make code style and quality * fix some unnoticed error * make code style and quality * add example controlnet in readme * add test controlnet * rm Remove duplicate notes * Fix formatting errors * add new control image * add model cpu offload * update help for adafactor * make quality & style * make quality and style * rename flux_controlnet_model_name_or_path * fix back src/diffusers/pipelines/flux/pipeline_flux_controlnet.py * fix dtype error by pre calculate text emb * rm image save * quality fix * fix test * fix tiny flux train error * change report to to tensorboard * fix save name error when test * Fix shrinking errors --------- Co-authored-by: YiYi Xu Co-authored-by: Sayak Paul Co-authored-by: Your Name --- examples/controlnet/README_flux.md | 430 ++++++ examples/controlnet/requirements_flux.txt | 9 + examples/controlnet/test_controlnet.py | 25 + examples/controlnet/train_controlnet_flux.py | 1434 ++++++++++++++++++ 4 files changed, 1898 insertions(+) create mode 100644 examples/controlnet/README_flux.md create mode 100644 examples/controlnet/requirements_flux.txt create mode 100644 examples/controlnet/train_controlnet_flux.py diff --git a/examples/controlnet/README_flux.md b/examples/controlnet/README_flux.md new file mode 100644 index 000000000000..d8be36a6e17a --- /dev/null +++ b/examples/controlnet/README_flux.md @@ -0,0 +1,430 @@ +# ControlNet training example for FLUX + +The `train_controlnet_flux.py` script shows how to implement the ControlNet training procedure and adapt it for [FLUX](https://github.com/black-forest-labs/flux). + +Training script provided by LibAI, which is an institution dedicated to the progress and achievement of artificial general intelligence. LibAI is the developer of [cutout.pro](https://www.cutout.pro/) and [promeai.pro](https://www.promeai.pro/). +> [!NOTE] +> **Memory consumption** +> +> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual. + +> **Gated access** +> +> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: `huggingface-cli login` + + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/controlnet` folder and run +```bash +pip install -r requirements_flux.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. + +## Custom Datasets + +We support dataset formats: +The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. To use our example, add `--dataset_name=fusing/fill50k \` to the script and remove line `--jsonl_for_train` mentioned below. + + +We also support importing data from jsonl(xxx.jsonl),using `--jsonl_for_train` to enable it, here is a brief example of jsonl files: +```sh +{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"} +{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"} +``` + +## Training + +Our training examples use two test conditioning images. They can be downloaded by running + +```sh +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png +``` + +Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub. + +we can define the num_layers, num_single_layers, which determines the size of the control(default values are num_layers=4, num_single_layers=10) + + +```bash +accelerate launch train_controlnet_flux.py \ + --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ + --dataset_name=fusing/fill50k \ + --conditioning_image_column=conditioning_image \ + --image_column=image \ + --caption_column=text \ + --output_dir="path to save model" \ + --mixed_precision="bf16" \ + --resolution=512 \ + --learning_rate=1e-5 \ + --max_train_steps=15000 \ + --validation_steps=100 \ + --checkpointing_steps=200 \ + --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --report_to="wandb" \ + --num_double_layers=4 \ + --num_single_layers=0 \ + --seed=42 \ + --push_to_hub \ +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. +* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +Our experiments were conducted on a single 80GB A100 GPU. + +### Inference + +Once training is done, we can perform inference like so: + +```python +import torch +from diffusers.utils import load_image +from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline +from diffusers.models.controlnet_flux import FluxControlNetModel + +base_model = 'black-forest-labs/FLUX.1-dev' +controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai' +controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) +pipe = FluxControlNetPipeline.from_pretrained( + base_model, + controlnet=controlnet, + torch_dtype=torch.bfloat16 +) +# enable memory optimizations +pipe.enable_model_cpu_offload() + +control_image = load_image("https://huggingface.co/promeai/FLUX.1-controlnet-lineart-promeai/resolve/main/images/example-control.jpg")resize((1024, 1024)) +prompt = "cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere" + +image = pipe( + prompt, + control_image=control_image, + controlnet_conditioning_scale=0.6, + num_inference_steps=28, + guidance_scale=3.5, +).images[0] +image.save("./output.png") +``` + +## Apply Deepspeed Zero3 + +This is an experimental process, I am not sure if it is suitable for everyone, we used this process to successfully train 512 resolution on A100(40g) * 8. +Please modify some of the code in the script. +### 1.Customize zero3 settings + +Copy the **accelerate_config_zero3.yaml**,modify `num_processes` according to the number of gpus you want to use: + +```bash +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 8 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +``` + +### 2.Precompute all inputs (latent, embeddings) + +In the train_controlnet_flux.py, We need to pre-calculate all parameters and put them into batches.So we first need to rewrite the `compute_embeddings` function. + +```python +def compute_embeddings(batch, proportion_empty_prompts, vae, flux_controlnet_pipeline, weight_dtype, is_train=True): + + ### compute text embeddings + prompt_batch = batch[args.caption_column] + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + prompt_batch = captions + prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt( + prompt_batch, prompt_2=prompt_batch + ) + prompt_embeds = prompt_embeds.to(dtype=weight_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype) + text_ids = text_ids.to(dtype=weight_dtype) + + # text_ids [512,3] to [bs,512,3] + text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1) + + ### compute latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + return latents + + # vae encode + pixel_values = batch["pixel_values"] + pixel_values = torch.stack([image for image in pixel_values]).to(dtype=weight_dtype).to(vae.device) + pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample() + pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor + pixel_latents = _pack_latents( + pixel_latents_tmp, + pixel_values.shape[0], + pixel_latents_tmp.shape[1], + pixel_latents_tmp.shape[2], + pixel_latents_tmp.shape[3], + ) + + control_values = batch["conditioning_pixel_values"] + control_values = torch.stack([image for image in control_values]).to(dtype=weight_dtype).to(vae.device) + control_latents = vae.encode(control_values).latent_dist.sample() + control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor + control_latents = _pack_latents( + control_latents, + control_values.shape[0], + control_latents.shape[1], + control_latents.shape[2], + control_latents.shape[3], + ) + + # copied from pipeline_flux_controlnet + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + latent_image_ids = _prepare_latent_image_ids( + batch_size=pixel_latents_tmp.shape[0], + height=pixel_latents_tmp.shape[2], + width=pixel_latents_tmp.shape[3], + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + + # unet_added_cond_kwargs = {"pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids} + return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids, "pixel_latents": pixel_latents, "control_latents": control_latents, "latent_image_ids": latent_image_ids} +``` + +Because we need images to pass through vae, we need to preprocess the images in the dataset first. At the same time, vae requires more gpu memory, so you may need to modify the `batch_size` below +```diff ++train_dataset = prepare_train_dataset(train_dataset, accelerator) +with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map( +- compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=100 ++ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=10 + ) + +del text_encoders, tokenizers +gc.collect() +torch.cuda.empty_cache() + +# Then get the training dataset ready to be passed to the dataloader. +-train_dataset = prepare_train_dataset(train_dataset, accelerator) +``` +### 3.Redefine the behavior of getting batchsize + +Now that we have all the preprocessing done, we need to modify the `collate_fn` function. + +```python +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + pixel_latents = torch.stack([torch.tensor(example["pixel_latents"]) for example in examples]) + pixel_latents = pixel_latents.to(memory_format=torch.contiguous_format).float() + + control_latents = torch.stack([torch.tensor(example["control_latents"]) for example in examples]) + control_latents = control_latents.to(memory_format=torch.contiguous_format).float() + + latent_image_ids= torch.stack([torch.tensor(example["latent_image_ids"]) for example in examples]) + + prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + + pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) + text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "pixel_latents": pixel_latents, + "control_latents": control_latents, + "latent_image_ids": latent_image_ids, + "prompt_ids": prompt_ids, + "unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids}, + } +``` +Finally, we just need to modify the way of obtaining various parameters during training. +```python +for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(flux_controlnet): + # Convert images to latent space + pixel_latents = batch["pixel_latents"].to(dtype=weight_dtype) + control_image = batch["control_latents"].to(dtype=weight_dtype) + latent_image_ids = batch["latent_image_ids"].to(dtype=weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype) + bsz = pixel_latents.shape[0] + + # Sample a random timestep for each image + t = torch.sigmoid(torch.randn((bsz,), device=accelerator.device, dtype=weight_dtype)) + + # apply flow matching + noisy_latents = ( + 1 - t.unsqueeze(1).unsqueeze(2).repeat(1, pixel_latents.shape[1], pixel_latents.shape[2]) + ) * pixel_latents + t.unsqueeze(1).unsqueeze(2).repeat( + 1, pixel_latents.shape[1], pixel_latents.shape[2] + ) * noise + + guidance_vec = torch.full( + (noisy_latents.shape[0],), 3.5, device=noisy_latents.device, dtype=weight_dtype + ) + + controlnet_block_samples, controlnet_single_block_samples = flux_controlnet( + hidden_states=noisy_latents, + controlnet_cond=control_image, + timestep=t, + guidance=guidance_vec, + pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype), + encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype), + txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype), + img_ids=latent_image_ids[0], + return_dict=False, + ) + + noise_pred = flux_transformer( + hidden_states=noisy_latents, + timestep=t, + guidance=guidance_vec, + pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype), + encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype), + controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples] + if controlnet_block_samples is not None + else None, + controlnet_single_block_samples=[ + sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples + ] + if controlnet_single_block_samples is not None + else None, + txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype), + img_ids=latent_image_ids[0], + return_dict=False, + )[0] +``` +Congratulations! You have completed all the required code modifications required for deepspeedzero3. + +### 4.Training with deepspeedzero3 + +Start!!! + +```bash +export pretrained_model_name_or_path='flux-dev-model-path' +export MODEL_TYPE='train_model_type' +export TRAIN_JSON_FILE="your_json_file" +export CONTROL_TYPE='control_preprocessor_type' +export CAPTION_COLUMN='caption_column' + +export CACHE_DIR="/data/train_csr/.cache/huggingface/" +export OUTPUT_DIR='/data/train_csr/FLUX/MODEL_OUT/'$MODEL_TYPE +# The first step is to use Python to precompute all caches.Replace the first line below with this line. (I am not sure why using acclerate would cause problems.) + +CUDA_VISIBLE_DEVICES=0 python3 train_controlnet_flux.py \ + +# The second step is to use the above accelerate config to train +accelerate launch --config_file "./accelerate_config_zero3.yaml" train_controlnet_flux.py \ + --pretrained_model_name_or_path=$pretrained_model_name_or_path \ + --jsonl_for_train=$TRAIN_JSON_FILE \ + --conditioning_image_column=$CONTROL_TYPE \ + --image_column=image \ + --caption_column=$CAPTION_COLUMN\ + --cache_dir=$CACHE_DIR \ + --tracker_project_name=$MODEL_TYPE \ + --output_dir=$OUTPUT_DIR \ + --max_train_steps=500000 \ + --mixed_precision bf16 \ + --checkpointing_steps=1000 \ + --gradient_accumulation_steps=8 \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=1e-5 \ + --num_double_layers=4 \ + --num_single_layers=0 \ + --gradient_checkpointing \ + --resume_from_checkpoint="latest" \ + # --use_adafactor \ dont use + # --validation_steps=3 \ not support + # --validation_image $VALIDATION_IMAGE \ not support + # --validation_prompt "xxx" \ not support +``` \ No newline at end of file diff --git a/examples/controlnet/requirements_flux.txt b/examples/controlnet/requirements_flux.txt new file mode 100644 index 000000000000..388444fbc65b --- /dev/null +++ b/examples/controlnet/requirements_flux.txt @@ -0,0 +1,9 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 +datasets +wandb +SentencePiece \ No newline at end of file diff --git a/examples/controlnet/test_controlnet.py b/examples/controlnet/test_controlnet.py index 77b5614c7fb0..3c508f80f1a4 100644 --- a/examples/controlnet/test_controlnet.py +++ b/examples/controlnet/test_controlnet.py @@ -136,3 +136,28 @@ def test_controlnet_sd3(self): run_command(self._launch_args + test_args) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) + + +class ControlNetflux(ExamplesTestsAccelerate): + def test_controlnet_flux(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/controlnet/train_controlnet_flux.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe + --output_dir={tmpdir} + --dataset_name=hf-internal-testing/fill10 + --conditioning_image_column=conditioning_image + --image_column=image + --caption_column=text + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + --num_double_layers=1 + --num_single_layers=1 + """.split() + + run_command(self._launch_args + test_args) + + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py new file mode 100644 index 000000000000..e344a9b1e2a5 --- /dev/null +++ b/examples/controlnet/train_controlnet_flux.py @@ -0,0 +1,1434 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import ( + AutoTokenizer, + CLIPTextModel, + T5EncoderModel, +) + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxTransformer2DModel, +) +from diffusers.models.controlnet_flux import FluxControlNetModel +from diffusers.optimization import get_scheduler +from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline +from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) +if is_torch_npu_available(): + torch.npu.config.allow_internal_format = False + + +def log_validation( + vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False +): + logger.info("Running validation... ") + + if not is_final_validation: + flux_controlnet = accelerator.unwrap_model(flux_controlnet) + pipeline = FluxControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=flux_controlnet, + transformer=flux_transformer, + torch_dtype=torch.bfloat16, + ) + else: + flux_controlnet = FluxControlNetModel.from_pretrained( + args.output_dir, torch_dtype=torch.bfloat16, variant=args.save_weight_dtype + ) + pipeline = FluxControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=flux_controlnet, + transformer=flux_transformer, + torch_dtype=torch.bfloat16, + ) + + pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + if is_final_validation or torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + from diffusers.utils import load_image + + validation_image = load_image(validation_image) + # maybe need to inference on 1024 to get a good image + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + # pre calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( + validation_prompt, prompt_2=validation_prompt + ) + for _ in range(args.num_validation_images): + with autocast_ctx: + # need to fix in pipeline_flux_controlnet + image = pipeline( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + control_image=validation_image, + num_inference_steps=28, + controlnet_conditioning_scale=0.7, + guidance_scale=3.5, + generator=generator, + ).images[0] + images.append(image) + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + clear_objs_and_retain_memory([pipeline]) + return image_logs + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "flux", + "flux-diffusers", + "text-to-image", + "diffusers", + "controlnet", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--use_adafactor", + action="store_true", + help=( + "Adafactor is a stochastic optimization method based on Adam that reduces memory usage while retaining" + "the empirical benefits of adaptivity. This is achieved through maintaining a factored representation " + "of the squared gradient accumulator across training steps." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_double_layers", + type=int, + default=4, + help="Number of double layers in the controlnet (default: 4).", + ) + parser.add_argument( + "--num_single_layers", + type=int, + default=4, + help="Number of single layers in the controlnet (default: 4).", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=2, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="flux_train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--jsonl_for_train", + type=str, + default=None, + help="Path to the jsonl file containing the training data.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the guidance scale used for transformer.", + ) + + parser.add_argument( + "--save_weight_dtype", + type=str, + default="fp32", + choices=[ + "fp16", + "bf16", + "fp32", + ], + help=("Preserve precision type according to selected weight"), + ) + + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--enable_model_cpu_offload", + action="store_true", + help="Enable model cpu offload and save memory.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.jsonl_for_train is None: + raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`") + + if args.dataset_name is not None and args.jsonl_for_train is not None: + raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def get_train_dataset(args, accelerator): + dataset = None + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + if args.jsonl_for_train is not None: + # load from json + dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir) + dataset = dataset.flatten_indices() + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [ + (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) + for image in examples[args.image_column] + ] + images = [image_transforms(image) for image in images] + + conditioning_images = [ + (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) + for image in examples[args.conditioning_image_column] + ] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + + pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) + text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "prompt_ids": prompt_ids, + "unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids}, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_out_dir = Path(args.output_dir, args.logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir)) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices. + if torch.backends.mps.is_available(): + print("MPS is enabled. Disabling AMP.") + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + # DEBUG, INFO, WARNING, ERROR, CRITICAL + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + # load clip tokenizer + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + # load t5 tokenizer + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + # load clip text encoder + text_encoder_one = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + # load t5 text encoder + text_encoder_two = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + flux_transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + flux_controlnet = FluxControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from transformer") + # we can define the num_layers, num_single_layers, + flux_controlnet = FluxControlNetModel.from_transformer( + flux_transformer, + attention_head_dim=flux_transformer.config["attention_head_dim"], + num_attention_heads=flux_transformer.config["num_attention_heads"], + num_layers=args.num_double_layers, + num_single_layers=args.num_single_layers, + ) + logger.info("all models loaded successfully") + + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae.requires_grad_(False) + flux_transformer.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + flux_controlnet.train() + + # use some pipeline function + flux_controlnet_pipeline = FluxControlNetPipeline( + scheduler=noise_scheduler, + vae=vae, + text_encoder=text_encoder_one, + tokenizer=tokenizer_one, + text_encoder_2=text_encoder_two, + tokenizer_2=tokenizer_two, + transformer=flux_transformer, + controlnet=flux_controlnet, + ) + if args.enable_model_cpu_offload: + flux_controlnet_pipeline.enable_model_cpu_offload() + else: + flux_controlnet_pipeline.to(accelerator.device) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "flux_controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = FluxControlNetModel.from_pretrained(input_dir, subfolder="flux_controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + flux_transformer.enable_npu_flash_attention() + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.") + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + flux_transformer.enable_xformers_memory_efficient_attention() + flux_controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + flux_transformer.enable_gradient_checkpointing() + flux_controlnet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(flux_controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {unwrap_model(flux_controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = flux_controlnet.parameters() + # use adafactor optimizer to save gpu memory + if args.use_adafactor: + from transformers import Adafactor + + optimizer = Adafactor( + params_to_optimize, + lr=args.learning_rate, + scale_parameter=False, + relative_step=False, + # warmup_init=True, + weight_decay=args.adam_weight_decay, + ) + else: + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae.to(accelerator.device, dtype=weight_dtype) + flux_transformer.to(accelerator.device, dtype=weight_dtype) + + def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline, weight_dtype, is_train=True): + prompt_batch = batch[args.caption_column] + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + prompt_batch = captions + prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt( + prompt_batch, prompt_2=prompt_batch + ) + prompt_embeds = prompt_embeds.to(dtype=weight_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype) + text_ids = text_ids.to(dtype=weight_dtype) + + # text_ids [512,3] to [bs,512,3] + text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1) + return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids} + + train_dataset = get_train_dataset(args, accelerator) + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + compute_embeddings_fn = functools.partial( + compute_embeddings, + flux_controlnet_pipeline=flux_controlnet_pipeline, + proportion_empty_prompts=args.proportion_empty_prompts, + weight_dtype=weight_dtype, + ) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map( + compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50 + ) + + clear_objs_and_retain_memory([text_encoders, tokenizers]) + + # Then get the training dataset ready to be passed to the dataloader. + train_dataset = prepare_train_dataset(train_dataset, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + # Prepare everything with our `accelerator`. + flux_controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + flux_controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(flux_controlnet): + # Convert images to latent space + # vae encode + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample() + pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor + pixel_latents = FluxControlNetPipeline._pack_latents( + pixel_latents_tmp, + pixel_values.shape[0], + pixel_latents_tmp.shape[1], + pixel_latents_tmp.shape[2], + pixel_latents_tmp.shape[3], + ) + + control_values = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + control_latents = vae.encode(control_values).latent_dist.sample() + control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor + control_image = FluxControlNetPipeline._pack_latents( + control_latents, + control_values.shape[0], + control_latents.shape[1], + control_latents.shape[2], + control_latents.shape[3], + ) + + latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids( + batch_size=pixel_latents_tmp.shape[0], + height=pixel_latents_tmp.shape[2], + width=pixel_latents_tmp.shape[3], + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + + bsz = pixel_latents.shape[0] + noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype) + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype) + noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise + + # handle guidance + if flux_transformer.config.guidance_embeds: + guidance_vec = torch.full( + (noisy_model_input.shape[0],), + args.guidance_scale, + device=noisy_model_input.device, + dtype=weight_dtype, + ) + else: + guidance_vec = None + + controlnet_block_samples, controlnet_single_block_samples = flux_controlnet( + hidden_states=noisy_model_input, + controlnet_cond=control_image, + timestep=timesteps / 1000, + guidance=guidance_vec, + pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype), + encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype), + txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype), + img_ids=latent_image_ids, + return_dict=False, + ) + + noise_pred = flux_transformer( + hidden_states=noisy_model_input, + timestep=timesteps / 1000, + guidance=guidance_vec, + pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype), + encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype), + controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples] + if controlnet_block_samples is not None + else None, + controlnet_single_block_samples=[ + sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples + ] + if controlnet_single_block_samples is not None + else None, + txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype), + img_ids=latent_image_ids, + return_dict=False, + )[0] + + loss = F.mse_loss(noise_pred.float(), (noise - pixel_latents).float(), reduction="mean") + accelerator.backward(loss) + # Check if the gradient of each model parameter contains NaN + for name, param in flux_controlnet.named_parameters(): + if param.grad is not None and torch.isnan(param.grad).any(): + logger.error(f"Gradient for {name} contains NaN!") + + if accelerator.sync_gradients: + params_to_clip = flux_controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + vae=vae, + flux_transformer=flux_transformer, + flux_controlnet=flux_controlnet, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + ) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_controlnet = unwrap_model(flux_controlnet) + save_weight_dtype = torch.float32 + if args.save_weight_dtype == "fp16": + save_weight_dtype = torch.float16 + elif args.save_weight_dtype == "bf16": + save_weight_dtype = torch.bfloat16 + flux_controlnet.to(save_weight_dtype) + if args.save_weight_dtype != "fp32": + flux_controlnet.save_pretrained(args.output_dir, variant=args.save_weight_dtype) + else: + flux_controlnet.save_pretrained(args.output_dir) + # Run a final round of validation. + # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae=vae, + flux_transformer=flux_transformer, + flux_controlnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 81cf3b2f155f1de322079af28f625349ee21ec6b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 27 Sep 2024 23:27:09 +0530 Subject: [PATCH 26/35] [Tests] [LoRA] clean up the serialization stuff. (#9512) * clean up the serialization stuff. * better --- tests/lora/utils.py | 114 ++++++++++++++++---------------------------- 1 file changed, 41 insertions(+), 73 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 43c45daaa322..5def867324f4 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -201,6 +201,32 @@ def get_dummy_tokens(self): prepared_inputs["input_ids"] = inputs return prepared_inputs + def _get_lora_state_dicts(self, modules_to_save): + state_dicts = {} + for module_name, module in modules_to_save.items(): + if module is not None: + state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) + return state_dicts + + def _get_modules_to_save(self, pipe, has_denoiser=False): + modules_to_save = {} + lora_loadable_modules = self.pipeline_class._lora_loadable_modules + + if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"): + modules_to_save["text_encoder"] = pipe.text_encoder + + if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"): + modules_to_save["text_encoder_2"] = pipe.text_encoder_2 + + if has_denoiser: + if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): + modules_to_save["unet"] = pipe.unet + + if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): + modules_to_save["transformer"] = pipe.transformer + + return modules_to_save + def test_simple_inference(self): """ Tests a simple inference and makes sure it works as expected @@ -420,45 +446,21 @@ def test_simple_inference_with_text_lora_save_load(self): images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: - text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) - - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - text_encoder_2_lora_layers=text_encoder_2_state_dict, - safe_serialization=False, - ) - else: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - safe_serialization=False, - ) + modules_to_save = self._get_modules_to_save(pipe) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - if self.has_two_text_encoders: - if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - safe_serialization=False, - ) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), @@ -614,54 +616,20 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: - text_encoder_state_dict = ( - get_peft_model_state_dict(pipe.text_encoder) - if "text_encoder" in self.pipeline_class._lora_loadable_modules - else None + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts ) - denoiser_state_dict = get_peft_model_state_dict(denoiser) - - saving_kwargs = { - "save_directory": tmpdirname, - "safe_serialization": False, - } - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict}) - - if self.unet_kwargs is not None: - saving_kwargs.update({"unet_lora_layers": denoiser_state_dict}) - else: - saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict}) - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) - saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict}) - - self.pipeline_class.save_lora_weights(**saving_kwargs) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), "Loading from saved checkpoints should give same results.", From 11542431a52b02ef4f14b6f53354c79187884827 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 28 Sep 2024 09:57:31 +0530 Subject: [PATCH 27/35] [Core] fix variant-identification. (#9253) * fix variant-idenitification. * fix variant * fix sharded variant checkpoint loading. * Apply suggestions from code review * fixes. * more fixes. * remove print. * fixes * fixes * comments * fixes * apply suggestions. * hub_utils.py * fix test * updates * fixes * fixes * Apply suggestions from code review Co-authored-by: YiYi Xu * updates. * removep patch file. --------- Co-authored-by: YiYi Xu --- src/diffusers/models/model_loading_utils.py | 65 +++++++++++++ src/diffusers/models/modeling_utils.py | 44 +++++---- src/diffusers/pipelines/pipeline_utils.py | 49 +++++++--- src/diffusers/utils/hub_utils.py | 16 +++- tests/models/test_modeling_common.py | 96 ++++++++++++++++++- .../unets/test_models_unet_2d_condition.py | 56 ++++++----- tests/pipelines/test_pipelines.py | 49 +++++++++- tests/pipelines/test_pipelines_common.py | 68 +++++++++++++ 8 files changed, 381 insertions(+), 62 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 969eb5f5fa37..c9eb664443b5 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -31,6 +31,7 @@ WEIGHTS_INDEX_NAME, _add_variant, _get_model_file, + deprecate, is_accelerate_available, is_torch_version, logging, @@ -228,3 +229,67 @@ def _fetch_index_file( index_file = None return index_file + + +def _fetch_index_file_legacy( + is_local, + pretrained_model_name_or_path, + subfolder, + use_safetensors, + cache_dir, + variant, + force_download, + proxies, + local_files_only, + token, + revision, + user_agent, + commit_hash, +): + if is_local: + index_file = Path( + pretrained_model_name_or_path, + subfolder or "", + SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, + ).as_posix() + splits = index_file.split(".") + split_index = -3 if ".cache" in index_file else -2 + splits = splits[:-split_index] + [variant] + splits[-split_index:] + index_file = ".".join(splits) + if os.path.exists(index_file): + deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`." + deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False) + index_file = Path(index_file) + else: + index_file = None + else: + if variant is not None: + index_file_in_repo = Path( + subfolder or "", + SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, + ).as_posix() + splits = index_file_in_repo.split(".") + split_index = -2 + splits = splits[:-split_index] + [variant] + splits[-split_index:] + index_file_in_repo = ".".join(splits) + try: + index_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=index_file_in_repo, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=None, + user_agent=user_agent, + commit_hash=commit_hash, + ) + index_file = Path(index_file) + deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`." + deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False) + except (EntryNotFoundError, EnvironmentError): + index_file = None + + return index_file diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 9e0c50e8b37b..ad3433889fca 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -54,6 +54,7 @@ from .model_loading_utils import ( _determine_device_map, _fetch_index_file, + _fetch_index_file_legacy, _load_state_dict_into_model, load_model_dict_into_meta, load_state_dict, @@ -309,11 +310,9 @@ def save_pretrained( weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = _add_variant(weights_name, variant) - weight_name_split = weights_name.split(".") - if len(weight_name_split) in [2, 3]: - weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:]) - else: - raise ValueError(f"Invalid {weights_name} provided.") + weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) os.makedirs(save_directory, exist_ok=True) @@ -624,21 +623,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P is_sharded = False index_file = None is_local = os.path.isdir(pretrained_model_name_or_path) - index_file = _fetch_index_file( - is_local=is_local, - pretrained_model_name_or_path=pretrained_model_name_or_path, - subfolder=subfolder or "", - use_safetensors=use_safetensors, - cache_dir=cache_dir, - variant=variant, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - user_agent=user_agent, - commit_hash=commit_hash, - ) + index_file_kwargs = { + "is_local": is_local, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "subfolder": subfolder or "", + "use_safetensors": use_safetensors, + "cache_dir": cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. + if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) if index_file is not None and index_file.is_file(): is_sharded = True diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index ccd1c9485d0e..6721706b5689 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -50,7 +50,6 @@ DEPRECATED_REVISION_ARGS, BaseOutput, PushToHubMixin, - deprecate, is_accelerate_available, is_accelerate_version, is_torch_npu_available, @@ -58,7 +57,7 @@ logging, numpy_to_pil, ) -from ..utils.hub_utils import load_or_create_model_card, populate_model_card +from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module @@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: cached_folder = pretrained_model_name_or_path + # The variant filenames can have the legacy sharding checkpoint format that we check and throw + # a warning if detected. + if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant): + warn_msg = ( + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " + "Please check your files carefully:\n\n" + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" + "If you find any files in the deprecated format:\n" + "1. Remove all existing checkpoint files for this variant.\n" + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" + "This will ensure you're using the most up-to-date and compatible checkpoint format." + ) + logger.warning(warn_msg) + config_dict = cls.load_config(cached_folder) # pop out "_ignore_files" as it is only needed for download @@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` # with variant being `"fp16"`. model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict) + if len(model_variants) == 0 and variant is not None: + error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + raise ValueError(error_message) # 3. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it @@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: model_info_call_error = e # save error to reraise it if model is not cached locally if not local_files_only: + filenames = {sibling.rfilename for sibling in info.siblings} + if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): + warn_msg = ( + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " + "Please check your files carefully:\n\n" + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" + "If you find any files in the deprecated format:\n" + "1. Remove all existing checkpoint files for this variant.\n" + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" + "This will ensure you're using the most up-to-date and compatible checkpoint format." + ) + logger.warning(warn_msg) + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + config_file = hf_hub_download( pretrained_model_name, cls.config_name, @@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # retrieve all folder_names that contain relevant files folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] - filenames = {sibling.rfilename for sibling in info.siblings} - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) - diffusers_module = importlib.import_module(__name__.split(".")[0]) pipelines = getattr(diffusers_module, "pipelines") @@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) if len(variant_filenames) == 0 and variant is not None: - deprecation_message = ( - f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." - f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" - "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" - "modeling files is deprecated." - ) - deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) + error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + raise ValueError(error_message) # remove ignored filenames model_filenames = set(model_filenames) - set(ignore_filenames) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 1cdc02e87328..a79c6cdbfed8 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -271,8 +271,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: splits = weights_name.split(".") - split_index = -2 if weights_name.endswith(".index.json") else -1 - splits = splits[:-split_index] + [variant] + splits[-split_index:] + splits = splits[:-1] + [variant] + splits[-1:] weights_name = ".".join(splits) return weights_name @@ -502,6 +501,19 @@ def _get_checkpoint_shard_files( return cached_folder, sharded_metadata +def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None): + if filenames and folder: + raise ValueError("Both `filenames` and `folder` cannot be provided.") + if not filenames: + filenames = [] + for _, _, files in os.walk(folder): + for file in files: + filenames.append(os.path.basename(file)) + transformers_index_format = r"\d{5}-of-\d{5}" + variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$") + return any(variant_file_re.match(f) is not None for f in filenames) + + class PushToHubMixin: """ A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub. diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index b56ac233ef29..5548fdd0723d 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -27,8 +27,9 @@ import requests_mock import torch from accelerate.utils import compute_module_sizes -from huggingface_hub import ModelCard, delete_repo +from huggingface_hub import ModelCard, delete_repo, snapshot_download from huggingface_hub.utils import is_jinja_available +from parameterized import parameterized from requests.exceptions import HTTPError from diffusers.models import UNet2DConditionModel @@ -39,7 +40,13 @@ XFormersAttnProcessor, ) from diffusers.training_utils import EMAModel -from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging +from diffusers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + WEIGHTS_INDEX_NAME, + is_torch_npu_available, + is_xformers_available, + logging, +) from diffusers.utils.hub_utils import _add_variant from diffusers.utils.testing_utils import ( CaptureLogger, @@ -100,6 +107,52 @@ def test_accelerate_loading_error_message(self): # make sure that error message states what keys are missing assert "conv_out.bias" in str(error_context.exception) + @parameterized.expand( + [ + ("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", False), + ("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", True), + ("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, False), + ("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, True), + ] + ) + def test_variant_sharded_ckpt_legacy_format_raises_warning(self, repo_id, subfolder, use_local): + def load_model(path): + kwargs = {"variant": "fp16"} + if subfolder: + kwargs["subfolder"] = subfolder + return UNet2DConditionModel.from_pretrained(path, **kwargs) + + with self.assertWarns(FutureWarning) as warning: + if use_local: + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = snapshot_download(repo_id=repo_id) + _ = load_model(tmpdirname) + else: + _ = load_model(repo_id) + + warning_message = str(warning.warnings[0].message) + self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message) + + # Local tests are already covered down below. + @parameterized.expand( + [ + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", None, "fp16"), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "unet", "fp16"), + ("hf-internal-testing/tiny-sd-unet-sharded-no-variants", None, None), + ("hf-internal-testing/tiny-sd-unet-sharded-no-variants-subfolder", "unet", None), + ] + ) + def test_variant_sharded_ckpt_loads_from_hub(self, repo_id, subfolder, variant=None): + def load_model(): + kwargs = {} + if variant: + kwargs["variant"] = variant + if subfolder: + kwargs["subfolder"] = subfolder + return UNet2DConditionModel.from_pretrained(repo_id, **kwargs) + + assert load_model() + def test_cached_files_are_used_when_no_internet(self): # A mock response for an HTTP head request to emulate server down response_mock = mock.Mock() @@ -924,6 +977,7 @@ def test_sharded_checkpoints_with_variant(self): # testing if loading works with the variant when the checkpoint is sharded should be # enough. model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant) + index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename))) @@ -976,6 +1030,44 @@ def test_sharded_checkpoints_device_map(self): new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + # This test is okay without a GPU because we're not running any execution. We're just serializing + # and check if the resultant files are following an expected format. + def test_variant_sharded_ckpt_right_format(self): + for use_safe in [True, False]: + extension = ".safetensors" if use_safe else ".bin" + config, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + + model_size = compute_module_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. + variant = "fp16" + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained( + tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe + ) + index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant) + self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant))) + + # Now check if the right number of shards exists. First, let's get the number of shards. + # Since this number can be dependent on the model being tested, it's important that we calculate it + # instead of hardcoding it. + expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant)) + actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)]) + self.assertTrue(actual_num_shards == expected_num_shards) + + # Check if the variant is present as a substring in the checkpoints. + shard_files = [ + file + for file in os.listdir(tmp_dir) + if file.endswith(extension) or ("index" in file and "json" in file) + ] + assert all(variant in f for f in shard_files) + + # Check if the sharded checkpoints were serialized in the right format. + shard_files = [file for file in os.listdir(tmp_dir) if file.endswith(extension)] + # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors + assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files) + @is_staging_test class ModelPushToHubTester(unittest.TestCase): diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index f354950b6075..37d55cedeb28 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1036,9 +1036,15 @@ def test_ip_adapter_plus(self): assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) @require_torch_gpu - def test_load_sharded_checkpoint_from_hub(self): + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), + ] + ) + def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy") + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict) @@ -1046,11 +1052,15 @@ def test_load_sharded_checkpoint_from_hub(self): assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_gpu - def test_load_sharded_checkpoint_from_hub_subfolder(self): + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), + ] + ) + def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained( - "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet" - ) + loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict) @@ -1080,20 +1090,30 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self): assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_gpu - def test_load_sharded_checkpoint_device_map_from_hub(self): + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto") + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto") new_output = loaded_model(**inputs_dict) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_gpu - def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self): + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained( - "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto" - ) + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto") new_output = loaded_model(**inputs_dict) assert loaded_model @@ -1121,18 +1141,6 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) - @require_torch_gpu - def test_load_sharded_checkpoint_with_variant_from_hub(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained( - "hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16" - ) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - @require_peft_backend def test_lora(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index c73a12a4cbf8..8b087db6726e 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -30,6 +30,7 @@ import safetensors.torch import torch import torch.nn as nn +from huggingface_hub import snapshot_download from parameterized import parameterized from PIL import Image from requests.exceptions import HTTPError @@ -551,6 +552,50 @@ def test_download_variant_partly(self): assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 assert not any(f.endswith(other_format) for f in files) + def test_download_variants_with_sharded_checkpoints(self): + # Here we test for downloading of "variant" files belonging to the `unet` and + # the `text_encoder`. Their checkpoints can be sharded. + for use_safetensors in [True, False]: + for variant in ["fp16", None]: + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe-variants-right-format", + safety_checker=None, + cache_dir=tmpdirname, + variant=variant, + use_safetensors=use_safetensors, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # Check for `model_ext` and `variant`. + model_ext = ".safetensors" if use_safetensors else ".bin" + unexpected_ext = ".bin" if use_safetensors else ".safetensors" + model_files = [f for f in files if f.endswith(model_ext)] + assert not any(f.endswith(unexpected_ext) for f in files) + assert all(variant in f for f in model_files if f.endswith(model_ext) and variant is not None) + + def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self): + repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds" + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant" + + for is_local in [True, False]: + with CaptureLogger(logger) as cap_logger: + with tempfile.TemporaryDirectory() as tmpdirname: + local_repo_id = repo_id + if is_local: + local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname) + + _ = DiffusionPipeline.from_pretrained( + local_repo_id, + safety_checker=None, + variant="fp16", + use_safetensors=True, + ) + assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs" + def test_download_safetensors_only_variant_exists_for_model(self): variant = None use_safetensors = True @@ -655,7 +700,7 @@ def test_local_save_load_index(self): out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) + pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe) pipe_2 = StableDiffusionPipeline.from_pretrained( tmpdirname, safe_serialization=use_safe, variant=variant ) @@ -1646,7 +1691,7 @@ def test_name_or_path(self): def test_error_no_variant_available(self): variant = "fp16" with self.assertRaises(ValueError) as error_context: - _ = StableDiffusionPipeline.download( + _ = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant ) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 49da08e2ca45..3e6f9d1278e8 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1824,6 +1824,74 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs): # accounts for models that modify the number of inference steps based on strength assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps) + def test_serialization_with_variants(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + model_components = [ + component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module) + ] + variant = "fp16" + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) + + with open(f"{tmpdir}/model_index.json", "r") as f: + config = json.load(f) + + for subfolder in os.listdir(tmpdir): + if not os.path.isfile(subfolder) and subfolder in model_components: + folder_path = os.path.join(tmpdir, subfolder) + is_folder = os.path.isdir(folder_path) and subfolder in config + assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) + + def test_loading_with_variants(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + variant = "fp16" + + def is_nan(tensor): + if tensor.ndimension() == 0: + has_nan = torch.isnan(tensor).item() + else: + has_nan = torch.isnan(tensor).any() + return has_nan + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant) + + model_components_pipe = { + component_name: component + for component_name, component in pipe.components.items() + if isinstance(component, nn.Module) + } + model_components_pipe_loaded = { + component_name: component + for component_name, component in pipe_loaded.components.items() + if isinstance(component, nn.Module) + } + for component_name in model_components_pipe: + pipe_component = model_components_pipe[component_name] + pipe_loaded_component = model_components_pipe_loaded[component_name] + for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()): + # nan check for luminanext (mps). + if not (is_nan(p1) and is_nan(p2)): + self.assertTrue(torch.equal(p1, p2)) + + def test_loading_with_incorrect_variants_raises_error(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + variant = "fp16" + + with tempfile.TemporaryDirectory() as tmpdir: + # Don't save with variants. + pipe.save_pretrained(tmpdir, safe_serialization=False) + + with self.assertRaises(ValueError) as error: + _ = self.pipeline_class.from_pretrained(tmpdir, variant=variant) + + assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception) + def test_StableDiffusionMixin_component(self): """Any pipeline that have LDMFuncMixin should have vae and unet components.""" if not issubclass(self.pipeline_class, StableDiffusionMixin): From bd4df2856ae399dc55e4ded57164eff8ecf6cb65 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 28 Sep 2024 17:09:30 +0530 Subject: [PATCH 28/35] [refactor] remove conv_cache from CogVideoX VAE (#9524) * remove conv cache from the layer and pass as arg instead * make style * yiyi's cleaner implementation Co-Authored-By: YiYi Xu * sayak's compiled implementation Co-Authored-By: Sayak Paul --------- Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- .../autoencoders/autoencoder_kl_cogvideox.py | 265 ++++++++++++------ 1 file changed, 181 insertions(+), 84 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 04c787ee3e84..a91180b11825 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import torch @@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + memory_count = ( + (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3 + ) # Set to 2GB, suitable for CuDNN if memory_count > 2: @@ -115,34 +117,24 @@ def __init__( dilation=dilation, ) - self.conv_cache = None - - def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: + def fake_context_parallel_forward( + self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None + ) -> torch.Tensor: kernel_size = self.time_kernel_size if kernel_size > 1: - cached_inputs = ( - [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) - ) + cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) inputs = torch.cat(cached_inputs + [inputs], dim=2) return inputs - def _clear_fake_context_parallel_cache(self): - del self.conv_cache - self.conv_cache = None - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - inputs = self.fake_context_parallel_forward(inputs) - - self._clear_fake_context_parallel_cache() - # Note: we could move these to the cpu for a lower maximum memory usage but its only a few - # hundred megabytes and so let's not do it for now - self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) inputs = F.pad(inputs, padding_2d, mode="constant", value=0) output = self.conv(inputs) - return output + return output, conv_cache class CogVideoXSpatialNorm3D(nn.Module): @@ -172,7 +164,12 @@ def __init__( self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + def forward( + self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + if f.shape[2] > 1 and f.shape[2] % 2 == 1: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] @@ -183,9 +180,12 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: else: zq = F.interpolate(zq, size=f.shape[-3:]) + conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f + new_f = norm_f * conv_y + conv_b + return new_f, new_conv_cache class CogVideoXResnetBlock3D(nn.Module): @@ -236,6 +236,7 @@ def __init__( self.out_channels = out_channels self.nonlinearity = get_activation(non_linearity) self.use_conv_shortcut = conv_shortcut + self.spatial_norm_dim = spatial_norm_dim if spatial_norm_dim is None: self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) @@ -279,34 +280,43 @@ def forward( inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + hidden_states = inputs if zq is not None: - hidden_states = self.norm1(hidden_states, zq) + hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1")) else: hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) + hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) if temb is not None: hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] if zq is not None: - hidden_states = self.norm2(hidden_states, zq) + hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2")) else: hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) if self.in_channels != self.out_channels: - inputs = self.conv_shortcut(inputs) + if self.use_conv_shortcut: + inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut( + inputs, conv_cache=conv_cache.get("conv_shortcut") + ) + else: + inputs = self.conv_shortcut(inputs) hidden_states = hidden_states + inputs - return hidden_states + return hidden_states, new_conv_cache class CogVideoXDownBlock3D(nn.Module): @@ -392,8 +402,16 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: - for resnet in self.resnets: + r"""Forward method of the `CogVideoXDownBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -402,17 +420,23 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + zq, + conv_cache=conv_cache.get(conv_cache_key), ) else: - hidden_states = resnet(hidden_states, temb, zq) + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - return hidden_states + return hidden_states, new_conv_cache class CogVideoXMidBlock3D(nn.Module): @@ -480,8 +504,16 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: - for resnet in self.resnets: + r"""Forward method of the `CogVideoXMidBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -490,13 +522,15 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) ) else: - hidden_states = resnet(hidden_states, temb, zq) + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) - return hidden_states + return hidden_states, new_conv_cache class CogVideoXUpBlock3D(nn.Module): @@ -584,9 +618,16 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: r"""Forward method of the `CogVideoXUpBlock3D` class.""" - for resnet in self.resnets: + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -595,17 +636,23 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + zq, + conv_cache=conv_cache.get(conv_cache_key), ) else: - hidden_states = resnet(hidden_states, temb, zq) + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) - return hidden_states + return hidden_states, new_conv_cache class CogVideoXEncoder3D(nn.Module): @@ -705,9 +752,18 @@ def __init__( self.gradient_checkpointing = False - def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: r"""The forward method of the `CogVideoXEncoder3D` class.""" - hidden_states = self.conv_in(sample) + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) if self.training and self.gradient_checkpointing: @@ -718,28 +774,44 @@ def custom_forward(*inputs): return custom_forward # 1. Down - for down_block in self.down_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), hidden_states, temb, None + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), + hidden_states, + temb, + None, + conv_cache=conv_cache.get(conv_cache_key), ) # 2. Mid - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb, None + hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + hidden_states, + temb, + None, + conv_cache=conv_cache.get("mid_block"), ) else: # 1. Down - for down_block in self.down_blocks: - hidden_states = down_block(hidden_states, temb, None) + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = down_block( + hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key) + ) # 2. Mid - hidden_states = self.mid_block(hidden_states, temb, None) + hidden_states, new_conv_cache["mid_block"] = self.mid_block( + hidden_states, temb, None, conv_cache=conv_cache.get("mid_block") + ) # 3. Post-process hidden_states = self.norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) - return hidden_states + + hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out")) + + return hidden_states, new_conv_cache class CogVideoXDecoder3D(nn.Module): @@ -846,9 +918,18 @@ def __init__( self.gradient_checkpointing = False - def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: r"""The forward method of the `CogVideoXDecoder3D` class.""" - hidden_states = self.conv_in(sample) + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) if self.training and self.gradient_checkpointing: @@ -859,28 +940,45 @@ def custom_forward(*inputs): return custom_forward # 1. Mid - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb, sample + hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + hidden_states, + temb, + sample, + conv_cache=conv_cache.get("mid_block"), ) # 2. Up - for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), hidden_states, temb, sample + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + hidden_states, + temb, + sample, + conv_cache=conv_cache.get(conv_cache_key), ) else: # 1. Mid - hidden_states = self.mid_block(hidden_states, temb, sample) + hidden_states, new_conv_cache["mid_block"] = self.mid_block( + hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block") + ) # 2. Up - for up_block in self.up_blocks: - hidden_states = up_block(hidden_states, temb, sample) + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = up_block( + hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key) + ) # 3. Post-process - hidden_states = self.norm_out(hidden_states, sample) + hidden_states, new_conv_cache["norm_out"] = self.norm_out( + hidden_states, sample, conv_cache=conv_cache.get("norm_out") + ) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) - return hidden_states + hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out")) + + return hidden_states, new_conv_cache class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): @@ -1019,12 +1117,6 @@ def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): module.gradient_checkpointing = value - def _clear_fake_context_parallel_cache(self): - for name, module in self.named_modules(): - if isinstance(module, CogVideoXCausalConv3d): - logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") - module._clear_fake_context_parallel_cache() - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -1091,20 +1183,20 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: frame_batch_size = self.num_sample_frames_batch_size # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + conv_cache = None enc = [] + for i in range(num_batches): remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) end_frame = frame_batch_size * (i + 1) + remaining_frames x_intermediate = x[:, :, start_frame:end_frame] - x_intermediate = self.encoder(x_intermediate) + x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache) if self.quant_conv is not None: x_intermediate = self.quant_conv(x_intermediate) enc.append(x_intermediate) - self._clear_fake_context_parallel_cache() enc = torch.cat(enc, dim=2) - return enc @apply_forward_hook @@ -1143,7 +1235,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut frame_batch_size = self.num_latent_frames_batch_size num_batches = num_frames // frame_batch_size + conv_cache = None dec = [] + for i in range(num_batches): remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) @@ -1151,10 +1245,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut z_intermediate = z[:, :, start_frame:end_frame] if self.post_quant_conv is not None: z_intermediate = self.post_quant_conv(z_intermediate) - z_intermediate = self.decoder(z_intermediate) + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) dec.append(z_intermediate) - self._clear_fake_context_parallel_cache() dec = torch.cat(dec, dim=2) if not return_dict: @@ -1238,7 +1331,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: for j in range(0, width, overlap_width): # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + conv_cache = None time = [] + for k in range(num_batches): remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) @@ -1250,11 +1345,11 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width, ] - tile = self.encoder(tile) + tile, conv_cache = self.encoder(tile, conv_cache=conv_cache) if self.quant_conv is not None: tile = self.quant_conv(tile) time.append(tile) - self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) rows.append(row) @@ -1315,7 +1410,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod row = [] for j in range(0, width, overlap_width): num_batches = num_frames // frame_batch_size + conv_cache = None time = [] + for k in range(num_batches): remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) @@ -1329,9 +1426,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod ] if self.post_quant_conv is not None: tile = self.post_quant_conv(tile) - tile = self.decoder(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) time.append(tile) - self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) rows.append(row) From b28675c605a89c3da94e9792bec91b32209d191b Mon Sep 17 00:00:00 2001 From: Anand Kumar <63339285+AnandK27@users.noreply.github.com> Date: Sat, 28 Sep 2024 08:31:37 -0700 Subject: [PATCH 29/35] [train_instruct_pix2pix.py]Fix the LR schedulers when `num_train_epochs` is passed in a distributed training env (#9316) Fixed pix2pix lr scheduler Co-authored-by: Sayak Paul --- .../train_instruct_pix2pix.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index e5b7eaac4a1f..3cb0c6702599 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -747,17 +747,22 @@ def collate_fn(examples): ) # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, ) # Prepare everything with our `accelerator`. @@ -782,8 +787,14 @@ def collate_fn(examples): # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: + if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) From 8e7d6c03a366fdb0f551ce7b92f0871c863d4e08 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 28 Sep 2024 21:08:45 +0530 Subject: [PATCH 30/35] [chore] fix: retain memory utility. (#9543) * fix: retain memory utility. * fix * quality * free_memory. --- examples/cogvideo/train_cogvideox_lora.py | 8 +++----- examples/controlnet/train_controlnet_flux.py | 8 +++++--- examples/controlnet/train_controlnet_sd3.py | 13 ++++++------ .../dreambooth/train_dreambooth_lora_flux.py | 11 ++++++---- .../dreambooth/train_dreambooth_lora_sd3.py | 20 +++++++++---------- src/diffusers/training_utils.py | 8 ++------ 6 files changed, 33 insertions(+), 35 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 137f3222f6d9..6787c37f93a8 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -38,10 +38,7 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed from diffusers.optimization import get_scheduler from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid -from diffusers.training_utils import ( - cast_training_params, - clear_objs_and_retain_memory, -) +from diffusers.training_utils import cast_training_params, free_memory from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -726,7 +723,8 @@ def log_validation( } ) - clear_objs_and_retain_memory([pipe]) + del pipe + free_memory() return videos diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index e344a9b1e2a5..5969218f3c3e 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -54,7 +54,7 @@ from diffusers.models.controlnet_flux import FluxControlNetModel from diffusers.optimization import get_scheduler from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline -from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling +from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available @@ -193,7 +193,8 @@ def log_validation( else: logger.warning(f"image logging not implemented for {tracker.name}") - clear_objs_and_retain_memory([pipeline]) + del pipeline + free_memory() return image_logs @@ -1103,7 +1104,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50 ) - clear_objs_and_retain_memory([text_encoders, tokenizers]) + del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + free_memory() # Then get the training dataset ready to be passed to the dataloader. train_dataset = prepare_train_dataset(train_dataset, accelerator) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 4b255c501d99..9ea78370f5e0 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -49,11 +49,7 @@ StableDiffusion3ControlNetPipeline, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import ( - clear_objs_and_retain_memory, - compute_density_for_timestep_sampling, - compute_loss_weighting_for_sd3, -) +from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v else: logger.warning(f"image logging not implemented for {tracker.name}") - clear_objs_and_retain_memory(pipeline) + del pipeline + free_memory() if not is_final_validation: controlnet.to(accelerator.device) @@ -1131,7 +1128,9 @@ def compute_text_embeddings(batch, text_encoders, tokenizers): new_fingerprint = Hasher.hash(args) train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) - clear_objs_and_retain_memory(text_encoders + tokenizers) + del text_encoder_one, text_encoder_two, text_encoder_three + del tokenizer_one, tokenizer_two, tokenizer_three + free_memory() train_dataloader = torch.utils.data.DataLoader( train_dataset, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6091622719ee..fcc11386abcf 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -55,9 +55,9 @@ from diffusers.training_utils import ( _set_state_dict_into_text_encoder, cast_training_params, - clear_objs_and_retain_memory, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, + free_memory, ) from diffusers.utils import ( check_min_version, @@ -1437,7 +1437,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two]) + del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + free_memory() # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1480,7 +1481,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) if args.validation_prompt is None: - clear_objs_and_retain_memory([vae]) + del vae + free_memory() # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -1817,7 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch_dtype=weight_dtype, ) if not args.train_text_encoder: - clear_objs_and_retain_memory([text_encoder_one, text_encoder_two]) + del text_encoder_one, text_encoder_two + free_memory() # Save the lora layers accelerator.wait_for_everyone() diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 3060813bbbdc..02f5a7ee0f7a 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -55,9 +55,9 @@ from diffusers.training_utils import ( _set_state_dict_into_text_encoder, cast_training_params, - clear_objs_and_retain_memory, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, + free_memory, ) from diffusers.utils import ( check_min_version, @@ -211,7 +211,8 @@ def log_validation( } ) - clear_objs_and_retain_memory(objs=[pipeline]) + del pipeline + free_memory() return images @@ -1106,7 +1107,8 @@ def main(args): image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) - clear_objs_and_retain_memory(objs=[pipeline]) + del pipeline + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1453,9 +1455,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection - clear_objs_and_retain_memory( - objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three] - ) + del tokenizers, text_encoders + del text_encoder_one, text_encoder_two, text_encoder_three + free_memory() # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1791,11 +1793,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): epoch=epoch, torch_dtype=weight_dtype, ) - objs = [] - if not args.train_text_encoder: - objs.extend([text_encoder_one, text_encoder_two, text_encoder_three]) - clear_objs_and_retain_memory(objs=objs) + del text_encoder_one, text_encoder_two, text_encoder_three + free_memory() # Save the lora layers accelerator.wait_for_everyone() diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 26d4a2a504c6..57bd9074870c 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting -def clear_objs_and_retain_memory(objs: List[Any]): - """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator.""" - if len(objs) >= 1: - for obj in objs: - del obj - +def free_memory(): + """Runs garbage collection. Then clears the cache of the available accelerator.""" gc.collect() if torch.cuda.is_available(): From f9fd511466376c7021470695a31ebb8ed8078856 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 30 Sep 2024 23:29:39 +0530 Subject: [PATCH 31/35] [LoRA] support Kohya Flux LoRAs that have text encoders as well (#9542) * support kohya flux loras that have tes. --- .../loaders/lora_conversion_utils.py | 41 ++++++++++++++++++- tests/lora/test_lora_layers_flux.py | 20 +++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index f6dea33e8e82..d829cc3a844b 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd): f"transformer.single_transformer_blocks.{i}.norm.linear", ) + remaining_keys = list(sds_sd.keys()) + te_state_dict = {} + if remaining_keys: + if not all(k.startswith("lora_te1") for k in remaining_keys): + raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}") + for key in remaining_keys: + if not key.endswith("lora_down.weight"): + continue + + lora_name = key.split(".")[0] + lora_name_up = f"{lora_name}.lora_up.weight" + lora_name_alpha = f"{lora_name}.alpha" + diffusers_name = _convert_text_encoder_lora_key(key, lora_name) + + if lora_name.startswith(("lora_te_", "lora_te1_")): + down_weight = sds_sd.pop(key) + sd_lora_rank = down_weight.shape[0] + te_state_dict[diffusers_name] = down_weight + te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up) + + if lora_name_alpha in sds_sd: + alpha = sds_sd.pop(lora_name_alpha).item() + scale = alpha / sd_lora_rank + + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + te_state_dict[diffusers_name] *= scale_down + te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up + if len(sds_sd) > 0: - logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}") + + if te_state_dict: + te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()} - return ait_sd + new_state_dict = {**ait_sd, **te_state_dict} + return new_state_dict return _convert_sd_scripts_to_ai_toolkit(state_dict) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 0c336ebc3cbf..a75f9df91047 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -228,6 +228,26 @@ def test_flux_kohya(self): assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4) + def test_flux_kohya_with_text_encoder(self): + self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline.enable_model_cpu_offload() + + prompt = "optimus is cleaning the house with broomstick" + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=4.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219]) + + assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4) + def test_flux_xlabs(self): self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") self.pipeline.fuse_lora() From c4a8979f3018fbffee33304c1940561f7a5cf613 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 30 Sep 2024 20:00:54 +0100 Subject: [PATCH 32/35] Add beta sigmas to other schedulers and update docs (#9538) --- docs/source/en/api/schedulers/overview.md | 1 + .../schedulers/scheduling_deis_multistep.py | 53 +++++++++++++++++- .../scheduling_dpmsolver_multistep.py | 55 ++++++++++++++++++- .../scheduling_dpmsolver_multistep_inverse.py | 54 +++++++++++++++++- .../schedulers/scheduling_dpmsolver_sde.py | 52 +++++++++++++++++- .../scheduling_dpmsolver_singlestep.py | 54 +++++++++++++++++- .../schedulers/scheduling_heun_discrete.py | 54 +++++++++++++++++- .../scheduling_k_dpm_2_ancestral_discrete.py | 52 +++++++++++++++++- .../schedulers/scheduling_k_dpm_2_discrete.py | 52 +++++++++++++++++- .../schedulers/scheduling_lms_discrete.py | 46 +++++++++++++++- .../schedulers/scheduling_sasolver.py | 53 +++++++++++++++++- .../schedulers/scheduling_unipc_multistep.py | 53 +++++++++++++++++- 12 files changed, 551 insertions(+), 28 deletions(-) diff --git a/docs/source/en/api/schedulers/overview.md b/docs/source/en/api/schedulers/overview.md index 2150357cc2b9..af287454e15d 100644 --- a/docs/source/en/api/schedulers/overview.md +++ b/docs/source/en/api/schedulers/overview.md @@ -52,6 +52,7 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso | sgm_uniform | init with `timestep_spacing="trailing"` | | simple | init with `timestep_spacing="trailing"` | | exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` | +| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` | All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers. diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 3b26befac64e..6fe8474aab87 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -22,10 +22,14 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate +from ..utils import deprecate, is_scipy_available from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -113,6 +117,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -141,11 +148,16 @@ def __init__( lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -263,6 +275,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -396,6 +411,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 924eefb0e98d..7677e37e9426 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,11 +21,15 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate +from ..utils import deprecate, is_scipy_available from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. use_lu_lambdas (`bool`, *optional*, defaults to `False`): Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of @@ -209,6 +216,7 @@ def __init__( euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, use_lu_lambdas: Optional[bool] = False, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), @@ -217,8 +225,12 @@ def __init__( steps_offset: int = 0, rescale_betas_zero_snr: bool = False, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) @@ -337,6 +349,8 @@ def set_timesteps( raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") if timesteps is not None and self.config.use_exponential_sigmas: raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") + if timesteps is not None and self.config.use_beta_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.") if timesteps is not None: timesteps = np.array(timesteps).astype(np.int64) @@ -388,6 +402,9 @@ def set_timesteps( elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -542,6 +559,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 4f024b8c4c75..c26a464518f0 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -21,11 +21,15 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate +from ..utils import deprecate, is_scipy_available from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -126,6 +130,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -161,13 +168,18 @@ def __init__( euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) @@ -219,6 +231,7 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.use_karras_sigmas = use_karras_sigmas self.use_exponential_sigmas = use_exponential_sigmas + self.use_beta_sigmas = use_beta_sigmas @property def step_index(self): @@ -276,6 +289,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_max = ( @@ -416,6 +432,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output def convert_model_output( self, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 3748de63388a..610e8d2d765c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -20,9 +20,14 @@ import torchsde from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import is_scipy_available from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" @@ -162,6 +167,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed is generated. timestep_spacing (`str`, defaults to `"linspace"`): @@ -185,12 +193,17 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -349,6 +362,9 @@ def set_timesteps( elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas) @@ -451,6 +467,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + @property def state_in_first_order(self): return self.sample is None diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 353baf08e81d..3329919cfb02 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,11 +21,14 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate, logging +from ..utils import deprecate, is_scipy_available, logging from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -125,6 +128,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. final_sigmas_type (`str`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. @@ -157,12 +163,17 @@ def __init__( lower_order_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if algorithm_type == "dpmsolver": deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) @@ -307,6 +318,8 @@ def set_timesteps( raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.") if timesteps is not None and self.config.use_exponential_sigmas: raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") + if timesteps is not None and self.config.use_beta_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.") num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps @@ -333,6 +346,9 @@ def set_timesteps( elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -484,6 +500,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index efcfdeb1d5ef..cb995df4af59 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -19,9 +19,14 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import is_scipy_available from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -99,6 +104,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -120,13 +128,18 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, clip_sample: Optional[bool] = False, clip_sample_range: float = 1.0, timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -258,6 +271,8 @@ def set_timesteps( raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") if timesteps is not None and self.config.use_exponential_sigmas: raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") + if timesteps is not None and self.config.use_beta_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.") num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps @@ -296,6 +311,9 @@ def set_timesteps( elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) @@ -386,6 +404,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + @property def state_in_first_order(self): return self.dt is None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 038aa19603ea..b1ec244e5a79 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -19,10 +19,15 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import is_scipy_available from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -93,6 +98,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -117,12 +125,17 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -258,6 +271,9 @@ def set_timesteps( elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) self.log_sigmas = torch.from_numpy(log_sigmas).to(device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -376,6 +392,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + @property def state_in_first_order(self): return self.sample is None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 8fbf66832668..422fe40556f0 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -19,9 +19,14 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import is_scipy_available from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -92,6 +97,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -116,12 +124,17 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -257,6 +270,9 @@ def set_timesteps( elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -389,6 +405,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def step( self, model_output: Union[torch.Tensor, np.ndarray], diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 5ef8ffb0dcbf..aed8c5828c75 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -17,6 +17,7 @@ from typing import List, Optional, Tuple, Union import numpy as np +import scipy.stats import torch from scipy import integrate @@ -113,6 +114,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -137,12 +141,15 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -297,6 +304,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -392,6 +402,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def step( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index ad79c69fc714..7188be5caaea 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -22,11 +22,15 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate +from ..utils import deprecate, is_scipy_available from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -124,6 +128,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -159,13 +166,18 @@ def __init__( lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -292,6 +304,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -425,6 +440,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def convert_model_output( self, model_output: torch.Tensor, diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 78cf0b6d16a7..195e9c8477a2 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -22,10 +22,14 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate +from ..utils import deprecate, is_scipy_available from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +if is_scipy_available(): + import scipy.stats + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -161,6 +165,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): the sigmas are determined according to a sequence of noise levels {σi}. use_exponential_sigmas (`bool`, *optional*, defaults to `False`): Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -198,13 +205,18 @@ def __init__( solver_p: SchedulerMixin = None, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" rescale_betas_zero_snr: bool = False, ): - if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.") + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -337,6 +349,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min": @@ -480,6 +495,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp() return sigmas + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = torch.Tensor( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def convert_model_output( self, model_output: torch.Tensor, From 33fafe3d143ca8380a9e405e7acfa69091d863fb Mon Sep 17 00:00:00 2001 From: JuanCarlosPi Date: Tue, 1 Oct 2024 01:04:42 -0500 Subject: [PATCH 33/35] Add PAG support to StableDiffusionControlNetPAGInpaintPipeline (#8875) * Add pag to controlnet inpainting pipeline --------- Co-authored-by: YiYi Xu --- docs/source/en/api/pipelines/pag.md | 3 + src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 2 + src/diffusers/pipelines/pag/__init__.py | 2 + .../pag/pipeline_pag_controlnet_sd_inpaint.py | 1544 +++++++++++++++++ .../pag/pipeline_pag_controlnet_sd_xl.py | 2 +- .../dummy_torch_and_transformers_objects.py | 15 + .../pag/test_pag_controlnet_sd_inpaint.py | 245 +++ 9 files changed, 1816 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py create mode 100644 tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index aa69598ae290..8e3c82ea9e27 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -55,6 +55,9 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial ## StableDiffusionControlNetPAGPipeline [[autodoc]] StableDiffusionControlNetPAGPipeline + +## StableDiffusionControlNetPAGInpaintPipeline +[[autodoc]] StableDiffusionControlNetPAGInpaintPipeline - all - __call__ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index dedb6f5c7f14..4214a4699ec8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -328,6 +328,7 @@ "StableDiffusionAttendAndExcitePipeline", "StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPAGInpaintPipeline", "StableDiffusionControlNetPAGPipeline", "StableDiffusionControlNetPipeline", "StableDiffusionControlNetXSPipeline", @@ -778,6 +779,7 @@ StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPAGInpaintPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionControlNetPipeline, StableDiffusionControlNetXSPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ddab5122d870..3b6cde17c8a3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -158,6 +158,7 @@ ) _import_structure["pag"].extend( [ + "StableDiffusionControlNetPAGInpaintPipeline", "AnimateDiffPAGPipeline", "KolorsPAGPipeline", "HunyuanDiTPAGPipeline", @@ -566,6 +567,7 @@ KolorsPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusion3PAGPipeline, + StableDiffusionControlNetPAGInpaintPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGImg2ImgPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index f6186da260ad..e3e78d0663fa 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -61,6 +61,7 @@ HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusion3PAGPipeline, + StableDiffusionControlNetPAGInpaintPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGImg2ImgPipeline, @@ -148,6 +149,7 @@ ("kandinsky", KandinskyInpaintCombinedPipeline), ("kandinsky22", KandinskyV22InpaintCombinedPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), + ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index d8842ce91175..a7ceb7e296d5 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] + _import_structure["pipeline_pag_controlnet_sd_inpaint"] = ["StableDiffusionControlNetPAGInpaintPipeline"] _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] _import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"] _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] @@ -44,6 +45,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline + from .pipeline_pag_controlnet_sd_inpaint import StableDiffusionControlNetPAGInpaintPipeline from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py new file mode 100644 index 000000000000..f5f117ab7625 --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -0,0 +1,1544 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..controlnet.multicontrolnet import MultiControlNetModel +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> import cv2 + >>> from diffusers import AutoPipelineForInpainting, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> from PIL import Image + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_canny_condition(image): + ... image = np.array(image) + ... image = cv2.Canny(image, 100, 200) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... image = Image.fromarray(image) + ... return image + + + >>> control_image = make_canny_condition(init_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... pag_scale=0.3, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionControlNetPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for image inpainting using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + + + This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting + ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as + default text-to-image Stable Diffusion checkpoints + ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image + Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as + [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: Union[str, List[str]] = "mid", + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if height is not None and height % 8 != 0 or width is not None and width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint.StableDiffusionControlNetInpaintPipeline.prepare_control_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, + `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to be used as the starting point. For both + NumPy array and PyTorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a NumPy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, + `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a NumPy array or PyTorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for PyTorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for NumPy array, it would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, + W, 1)`, or `(H, W)`. + control_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, + `List[List[torch.Tensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + mask_image, + height, + width, + output_type, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare control image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 7.1 Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 7.2 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.3 Prepare embeddings + # ip-adapter + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # control image + control_images = control_image if isinstance(control_image, list) else [control_image] + for i, single_control_image in enumerate(control_images): + if self.do_classifier_free_guidance: + single_control_image = single_control_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_control_image = self._prepare_perturbed_attention_guidance( + single_control_image, single_control_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_control_image = torch.cat([single_control_image] * 2) + single_control_image = single_control_image.to(device) + control_images[i] = single_control_image + + control_image = control_images if isinstance(control_image, list) else control_images[0] + controlnet_prompt_embeds = prompt_embeds + + # 7.4 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=False, + return_dict=False, + ) + + # concat latents, mask, masked_image_latents in the channel dimension + if num_channels_unet == 9: + first_dim_size = latent_model_input.shape[0] + # Ensure mask and masked_image_latents have the right dimensions + if mask.shape[0] < first_dim_size: + repeat_factor = (first_dim_size + mask.shape[0] - 1) // mask.shape[0] + mask = mask.repeat(repeat_factor, 1, 1, 1)[:first_dim_size] + if masked_image_latents.shape[0] < first_dim_size: + repeat_factor = ( + first_dim_size + masked_image_latents.shape[0] - 1 + ) // masked_image_latents.shape[0] + masked_image_latents = masked_image_latents.repeat(repeat_factor, 1, 1, 1)[:first_dim_size] + # Perform the concatenation + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index 247fc900a7b0..d19d9adc89c6 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -1347,7 +1347,7 @@ def __call__( latents, ) - # 6.5 Optionally get Guidance Scale Embedding + # 6.1 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4f22501ce7ec..1927fc8cd4d3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1352,6 +1352,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionControlNetPAGInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionControlNetPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py new file mode 100644 index 000000000000..0a7413e99926 --- /dev/null +++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily based on: + +import inspect +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPAGInpaintPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, +) +from diffusers.utils.torch_utils import randn_tensor + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +class StableDiffusionControlNetPAGInpaintPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): + pipeline_class = StableDiffusionControlNetPAGInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset({"control_image"}) # skip `image` and `mask` for now, only test for control_image + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self): + # Copied from tests.pipelines.controlnet.test_controlnet_inpaint.ControlNetInpaintPipelineFastTests.get_dummy_components + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + control_image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + init_image = init_image.cpu().permute(0, 2, 3, 1)[0] + + image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64)) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "pag_scale": 3.0, + "output_type": "np", + "image": image, + "mask_image": mask_image, + "control_image": control_image, + } + + return inputs + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline (expect same output when pag is disabled) + pipe_sd = StableDiffusionControlNetInpaintPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_pag_cfg(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 64, + 64, + 3, + ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}" + expected_slice = np.array( + [0.7488756, 0.61194265, 0.53382546, 0.5993959, 0.6193306, 0.56880975, 0.41277143, 0.5050145, 0.49376273] + ) + + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}" + + def test_pag_uncond(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["guidance_scale"] = 0.0 + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 64, + 64, + 3, + ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}" + expected_slice = np.array( + [0.7410303, 0.5989337, 0.530866, 0.60571927, 0.6162597, 0.5719856, 0.4187478, 0.5101238, 0.4978468] + ) + + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}" From 61d37640ade33f1e2d330a51466b94dfe4f155f6 Mon Sep 17 00:00:00 2001 From: Darren Hsu <35377472+darhsu@users.noreply.github.com> Date: Tue, 1 Oct 2024 19:08:12 -0700 Subject: [PATCH 34/35] Support bfloat16 for Upsample2D (#9480) * Support bfloat16 for Upsample2D * Add test and use is_torch_version * Resolve comments and add decorator * Simplify require_torch_version_greater_equal decorator * Run make style --------- Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- src/diffusers/models/upsampling.py | 12 ++++++------ src/diffusers/utils/testing_utils.py | 12 ++++++++++++ tests/models/test_layers_utils.py | 16 ++++++++++++++++ 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index fd5ed28c7070..cf07e45b0c5c 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from ..utils import deprecate +from ..utils.import_utils import is_torch_version from .normalization import RMSNorm @@ -151,11 +152,10 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None if self.use_conv_transpose: return self.conv(hidden_states) - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1 + # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767 dtype = hidden_states.dtype - if dtype == torch.bfloat16: + if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 @@ -170,8 +170,8 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: + # Cast back to original dtype + if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): hidden_states = hidden_states.to(dtype) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index be3e9983c80f..7dc3f414d55c 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -252,6 +252,18 @@ def require_torch_2(test_case): ) +def require_torch_version_greater_equal(torch_version): + """Decorator marking a test that requires torch with a specific version or greater.""" + + def decorator(test_case): + correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version) + return unittest.skipUnless( + correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}" + )(test_case) + + return decorator + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index 66e142f8c66a..415bb12b73c6 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -27,6 +27,7 @@ from diffusers.utils.testing_utils import ( backend_manual_seed, require_torch_accelerator_with_fp64, + require_torch_version_greater_equal, torch_device, ) @@ -120,6 +121,21 @@ def test_upsample_default(self): expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254]) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + @require_torch_version_greater_equal("2.1") + def test_upsample_bfloat16(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16) + upsample = Upsample2D(channels=32, use_conv=False) + with torch.no_grad(): + upsampled = upsample(sample) + + assert upsampled.shape == (1, 32, 64, 64) + output_slice = upsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor( + [-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16 + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + def test_upsample_with_conv(self): torch.manual_seed(0) sample = torch.randn(1, 32, 32, 32) From 7f323f0f3190533e596e09a4923dad1f73f23a91 Mon Sep 17 00:00:00 2001 From: Xiangchendong <66510463+Xiang-cd@users.noreply.github.com> Date: Thu, 3 Oct 2024 03:07:06 +0800 Subject: [PATCH 35/35] fix cogvideox autoencoder decode (#9569) Co-authored-by: Aryan --- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index a91180b11825..7834206ddb4a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -1234,7 +1234,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut return self.tiled_decode(z, return_dict=return_dict) frame_batch_size = self.num_latent_frames_batch_size - num_batches = num_frames // frame_batch_size + num_batches = max(num_frames // frame_batch_size, 1) conv_cache = None dec = []