From 9e9ed353a2859550815c71b45e801efbccc350c2 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Sat, 6 Jul 2024 11:32:04 -1000
Subject: [PATCH 1/2] fix loading sharded checkpoints from subfolder (#8798)

* fix load sharded checkpoints from subfolder
* style
* os.path.join
* add a small test

---------

Co-authored-by: sayakpaul
---
 src/diffusers/models/model_loading_utils.py | 2 +-
 src/diffusers/utils/hub_utils.py | 7 ++++++-
 tests/models/unets/test_models_unet_2d_condition.py | 12 ++++++++++++
 3 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 5604879f40ab..ebd356d981d6 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -221,7 +221,7 @@ def _fetch_index_file(
        local_files_only=local_files_only,
        token=token,
        revision=revision,
-       subfolder=subfolder,
+       subfolder=None,
        user_agent=user_agent,
        commit_hash=commit_hash,
    )
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index ce90fb09193b..7ecb7de89cd3 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -455,10 +455,13 @@ def _get_checkpoint_shard_files(
    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    allow_patterns = original_shard_filenames
+   if subfolder is not None:
+       allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
+
    ignore_patterns = ["*.json", "*.md"]
    if not local_files_only:
        # `model_info` call must be guarded with the above condition.
-       model_files_info = model_info(pretrained_model_name_or_path)
+       model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
        for shard_file in original_shard_filenames:
            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
            if not shard_file_present:
@@ -481,6 +484,8 @@ def _get_checkpoint_shard_files(
            ignore_patterns=ignore_patterns,
            user_agent=user_agent,
        )
+       if subfolder is not None:
+           cached_folder = os.path.join(cached_folder, subfolder)

    # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
    # we don't have to catch them here. We have also dealt with EntryNotFoundError.
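As a minimal sketch of what this fix enables (hedged: the repo id is the dummy checkpoint exercised by the new test below, and `UNet2DConditionModel` is the class that test uses), loading a sharded checkpoint stored under a subfolder of a Hub repo now works end to end:

```python
# Hedged sketch: load a sharded UNet whose shards live in the "unet"
# subfolder of a Hub repo. The repo id below is the dummy checkpoint
# used by the test added in this patch.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet"
)
```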
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 63e66dabf0c8..a84968e613b5 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1045,6 +1045,18 @@ def test_load_sharded_checkpoint_from_hub(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_from_hub_subfolder(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet" + ) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu def test_load_sharded_checkpoint_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 98388670d293a590e69d9ff6b442169829f47cf3 Mon Sep 17 00:00:00 2001 From: PommesPeter <54879512+PommesPeter@users.noreply.github.com> Date: Mon, 8 Jul 2024 11:12:09 +0800 Subject: [PATCH 2/2] [Alpha-VLLM Team] Add Lumina-T2X to diffusers (#8652) --------- Co-authored-by: zhuole1025 Co-authored-by: YiYi Xu --- docs/source/en/_toctree.yml | 6 + docs/source/en/api/models/lumina_nextdit2d.md | 20 + docs/source/en/api/pipelines/lumina.md | 88 ++ .../schedulers/flow_match_heun_discrete.md | 18 + scripts/convert_lumina_to_diffusers.py | 142 +++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention.py | 52 +- src/diffusers/models/attention_processor.py | 110 ++- src/diffusers/models/embeddings.py | 132 ++- src/diffusers/models/normalization.py | 84 +- src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/lumina_nextdit2d.py | 340 +++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/lumina/__init__.py | 48 + .../pipelines/lumina/pipeline_lumina.py | 897 ++++++++++++++++++ src/diffusers/schedulers/__init__.py | 2 + .../scheduling_flow_match_heun_discrete.py | 321 +++++++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/lumina/__init__.py | 0 tests/pipelines/lumina/test_lumina_nextdit.py | 179 ++++ 22 files changed, 2478 insertions(+), 17 deletions(-) create mode 100644 docs/source/en/api/models/lumina_nextdit2d.md create mode 100644 docs/source/en/api/pipelines/lumina.md create mode 100644 docs/source/en/api/schedulers/flow_match_heun_discrete.md create mode 100644 scripts/convert_lumina_to_diffusers.py create mode 100644 src/diffusers/models/transformers/lumina_nextdit2d.py create mode 100644 src/diffusers/pipelines/lumina/__init__.py create mode 100644 src/diffusers/pipelines/lumina/pipeline_lumina.py create mode 100644 src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py create mode 100644 tests/pipelines/lumina/__init__.py create mode 100644 tests/pipelines/lumina/test_lumina_nextdit.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1becbcaab4a5..7f378a34bde9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -249,6 +249,8 @@ title: DiTTransformer2DModel - local: api/models/hunyuan_transformer2d title: HunyuanDiT2DModel + - local: api/models/lumina_nextdit2d + title: LuminaNextDiT2DModel - local: api/models/transformer_temporal title: TransformerTemporalModel - local: 
api/models/sd3_transformer2d
@@ -324,6 +326,8 @@
      title: Latent Diffusion
    - local: api/pipelines/ledits_pp
      title: LEDITS++
+   - local: api/pipelines/lumina
+     title: Lumina-T2X
    - local: api/pipelines/marigold
      title: Marigold
    - local: api/pipelines/panorama
@@ -435,6 +439,8 @@
      title: EulerDiscreteScheduler
    - local: api/schedulers/flow_match_euler_discrete
      title: FlowMatchEulerDiscreteScheduler
+   - local: api/schedulers/flow_match_heun_discrete
+     title: FlowMatchHeunDiscreteScheduler
    - local: api/schedulers/heun
      title: HeunDiscreteScheduler
    - local: api/schedulers/ipndm
diff --git a/docs/source/en/api/models/lumina_nextdit2d.md b/docs/source/en/api/models/lumina_nextdit2d.md
new file mode 100644
index 000000000000..fe28918e2b58
--- /dev/null
+++ b/docs/source/en/api/models/lumina_nextdit2d.md
@@ -0,0 +1,20 @@
+
+
+# LuminaNextDiT2DModel
+
+A next-generation Diffusion Transformer (Next-DiT) model for 2D data from [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X).
+
+## LuminaNextDiT2DModel
+
+[[autodoc]] LuminaNextDiT2DModel
+
diff --git a/docs/source/en/api/pipelines/lumina.md b/docs/source/en/api/pipelines/lumina.md
new file mode 100644
index 000000000000..d26a27ded2fb
--- /dev/null
+++ b/docs/source/en/api/pipelines/lumina.md
@@ -0,0 +1,88 @@
+
+
+# Lumina-T2X
+![concepts](https://github.com/Alpha-VLLM/Lumina-T2X/assets/54879512/9f52eabb-07dc-4881-8257-6d8a5f2a0a5a)
+
+[Lumina-Next: Making Lumina-T2X Stronger and Faster with Next-DiT](https://github.com/Alpha-VLLM/Lumina-T2X/blob/main/assets/lumina-next.pdf) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
+
+The abstract from the paper is:
+
+*Lumina-T2X is a nascent family of Flow-based Large Diffusion Transformers (Flag-DiT) that establishes a unified framework for transforming noise into various modalities, such as images and videos, conditioned on text instructions. Despite its promising capabilities, Lumina-T2X still encounters challenges including training instability, slow inference, and extrapolation artifacts. In this paper, we present Lumina-Next, an improved version of Lumina-T2X, showcasing stronger generation performance with increased training and inference efficiency. We begin with a comprehensive analysis of the Flag-DiT architecture and identify several suboptimal components, which we address by introducing the Next-DiT architecture with 3D RoPE and sandwich normalizations. To enable better resolution extrapolation, we thoroughly compare different context extrapolation methods applied to text-to-image generation with 3D RoPE, and propose Frequency- and Time-Aware Scaled RoPE tailored for diffusion transformers. Additionally, we introduce a sigmoid time discretization schedule to reduce sampling steps in solving the Flow ODE and the Context Drop method to merge redundant visual tokens for faster network evaluation, effectively boosting the overall sampling speed. Thanks to these improvements, Lumina-Next not only improves the quality and efficiency of basic text-to-image generation but also demonstrates superior resolution extrapolation capabilities and multilingual generation using decoder-based LLMs as the text encoder, all in a zero-shot manner. To further validate Lumina-Next as a versatile generative framework, we instantiate it on diverse tasks including visual recognition, multi-view, audio, music, and point cloud generation, showcasing strong performance across these domains.
By releasing all codes and model weights at https://github.com/Alpha-VLLM/Lumina-T2X, we aim to advance the development of next-generation generative AI capable of universal modeling.*
+
+**Highlights**: Lumina-Next is a next-generation Diffusion Transformer that significantly enhances text-to-image generation, multilingual generation, and multitask performance by introducing the Next-DiT architecture, 3D RoPE, and frequency- and time-aware RoPE, among other improvements.
+
+Lumina-Next has the following components:
+* It improves sampling efficiency with fewer and faster sampling steps.
+* It uses Next-DiT as a transformer backbone with sandwich normalization, 3D RoPE, and Grouped-Query Attention.
+* It uses a Frequency- and Time-Aware Scaled RoPE.
+
+---
+
+[Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers](https://arxiv.org/abs/2405.05945) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
+
+The abstract from the paper is:
+
+*Sora unveils the potential of scaling Diffusion Transformer for generating photorealistic images and videos at arbitrary resolutions, aspect ratios, and durations, yet it still lacks sufficient implementation details. In this technical report, we introduce the Lumina-T2X family - a series of Flow-based Large Diffusion Transformers (Flag-DiT) equipped with zero-initialized attention, as a unified framework designed to transform noise into images, videos, multi-view 3D objects, and audio clips conditioned on text instructions. By tokenizing the latent spatial-temporal space and incorporating learnable placeholders such as [nextline] and [nextframe] tokens, Lumina-T2X seamlessly unifies the representations of different modalities across various spatial-temporal resolutions. This unified approach enables training within a single framework for different modalities and allows for flexible generation of multimodal data at any resolution, aspect ratio, and length during inference. Advanced techniques like RoPE, RMSNorm, and flow matching enhance the stability, flexibility, and scalability of Flag-DiT, enabling models of Lumina-T2X to scale up to 7 billion parameters and extend the context window to 128K tokens. This is particularly beneficial for creating ultra-high-definition images with our Lumina-T2I model and long 720p videos with our Lumina-T2V model. Remarkably, Lumina-T2I, powered by a 5-billion-parameter Flag-DiT, requires only 35% of the training computational costs of a 600-million-parameter naive DiT. Our further comprehensive analysis underscores Lumina-T2X's preliminary capability in resolution extrapolation, high-resolution editing, generating consistent 3D views, and synthesizing videos with seamless transitions. We expect that the open-sourcing of Lumina-T2X will further foster creativity, transparency, and diversity in the generative AI community.*
+
+
+You can find the original codebase at [Alpha-VLLM](https://github.com/Alpha-VLLM/Lumina-T2X) and all the available checkpoints at [Alpha-VLLM Lumina Family](https://huggingface.co/collections/Alpha-VLLM/lumina-family-66423205bedb81171fd0644b).
+
+**Highlights**: Lumina-T2X supports Any Modality, Resolution, and Duration.
+
+Lumina-T2X has the following components:
+* It uses a Flow-based Large Diffusion Transformer as the backbone.
+* It supports any modality with a single backbone and the corresponding encoder and decoder (see the quick-start sketch below).
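+
+A minimal text-to-image quick start (hedged: the checkpoint id is taken from the pipeline's example docstring added in this PR, and the prompt and output filename are illustrative):
+
+```python
+# Hedged sketch: basic generation with the pipeline added in this PR.
+import torch
+from diffusers import LuminaText2ImgPipeline
+
+pipe = LuminaText2ImgPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+image = pipe(prompt="A serene mountain lake at sunrise").images[0]
+image.save("lumina_quickstart.png")
+```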
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+### Inference (Text-to-Image)
+
+Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
+
+First, load the pipeline:
+
+```python
+from diffusers import LuminaText2ImgPipeline
+import torch
+
+pipeline = LuminaText2ImgPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+```
+
+Then change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:
+
+```python
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.vae.to(memory_format=torch.channels_last)
+```
+
+Finally, compile the components and run inference:
+
+```python
+pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
+pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
+
+image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
+```
+
+## LuminaText2ImgPipeline
+
+[[autodoc]] LuminaText2ImgPipeline
+  - all
+  - __call__
+
diff --git a/docs/source/en/api/schedulers/flow_match_heun_discrete.md b/docs/source/en/api/schedulers/flow_match_heun_discrete.md
new file mode 100644
index 000000000000..642f8ffc7dcc
--- /dev/null
+++ b/docs/source/en/api/schedulers/flow_match_heun_discrete.md
@@ -0,0 +1,18 @@
+
+
+# FlowMatchHeunDiscreteScheduler
+
+`FlowMatchHeunDiscreteScheduler` is based on the flow-matching sampling introduced in [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
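+
+As a hedged sketch of how this scheduler can be plugged into a flow-matching pipeline (swapping schedulers via `from_config` is standard diffusers practice; whether a given checkpoint pairs well with Heun sampling is an assumption):
+
+```python
+# Hedged sketch: swap the Heun flow-matching scheduler into a pipeline.
+# `from_config` rebuilds the scheduler from the existing scheduler config.
+import torch
+from diffusers import FlowMatchHeunDiscreteScheduler, LuminaText2ImgPipeline
+
+pipe = LuminaText2ImgPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
+)
+pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(pipe.scheduler.config)
+```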
+ +## FlowMatchHeunDiscreteScheduler +[[autodoc]] FlowMatchHeunDiscreteScheduler diff --git a/scripts/convert_lumina_to_diffusers.py b/scripts/convert_lumina_to_diffusers.py new file mode 100644 index 000000000000..a12625d1376f --- /dev/null +++ b/scripts/convert_lumina_to_diffusers.py @@ -0,0 +1,142 @@ +import argparse +import os + +import torch +from safetensors.torch import load_file +from transformers import AutoModel, AutoTokenizer + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline + + +def main(args): + # checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I + all_sd = load_file(args.origin_ckpt_path, device="cpu") + converted_state_dict = {} + # pad token + converted_state_dict["pad_token"] = all_sd["pad_token"] + + # patch embed + converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"] + converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"] + + # time and caption embed + converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"] + converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"] + converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"] + converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"] + converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"] + converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"] + converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"] + converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"] + + for i in range(24): + # adaln + converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"] + converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"] + converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"] + + # qkv + converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"] + converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"] + converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"] + + # cap + converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"] + converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"] + converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"] + + # output + converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"] + + # attention + # qk norm + converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"] + converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"] + + converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"] + converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"] + + converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = 
all_sd[f"layers.{i}.attention.q_norm.weight"] + converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"] + + converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"] + converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"] + + # attention norm + converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"] + converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"] + converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"] + + # feed forward + converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"] + converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"] + converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"] + + # feed forward norm + converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"] + converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"] + + # final layer + converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"] + converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"] + + converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"] + converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"] + + # Lumina-Next-SFT 2B + transformer = LuminaNextDiT2DModel( + sample_size=128, + patch_size=2, + in_channels=4, + hidden_size=2304, + num_layers=24, + num_attention_heads=32, + num_kv_heads=8, + multiple_of=256, + ffn_dim_multiplier=None, + norm_eps=1e-5, + learn_sigma=True, + qk_norm=True, + cross_attention_dim=2048, + scaling_factor=1.0, + ) + transformer.load_state_dict(converted_state_dict, strict=True) + + num_model_params = sum(p.numel() for p in transformer.parameters()) + print(f"Total number of transformer parameters: {num_model_params}") + + if args.only_transformer: + transformer.save_pretrained(os.path.join(args.dump_path, "transformer")) + else: + scheduler = FlowMatchEulerDiscreteScheduler() + + vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32) + + tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") + text_encoder = AutoModel.from_pretrained("google/gemma-2b") + + pipeline = LuminaText2ImgPipeline( + tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler + ) + pipeline.save_pretrained(args.dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." 
+    )
+    parser.add_argument(
+        "--image_size",
+        default=1024,
+        type=int,
+        choices=[256, 512, 1024],
+        required=False,
+        help="Image size of the pretrained model: 256, 512, or 1024.",
+    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+    # Note: `type=bool` would treat any non-empty string (even "False") as True, so parse the flag explicitly.
+    parser.add_argument(
+        "--only_transformer", default=True, type=lambda x: str(x).lower() in ("true", "1"), required=False
+    )
+
+    args = parser.parse_args()
+    main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 6f80cab0f357..85f3b7a127f7 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -88,6 +88,7 @@
        "HunyuanDiT2DMultiControlNetModel",
        "I2VGenXLUNet",
        "Kandinsky3UNet",
+       "LuminaNextDiT2DModel",
        "ModelMixin",
        "MotionAdapter",
        "MultiAdapter",
@@ -162,6 +163,7 @@
        "EulerAncestralDiscreteScheduler",
        "EulerDiscreteScheduler",
        "FlowMatchEulerDiscreteScheduler",
+       "FlowMatchHeunDiscreteScheduler",
        "HeunDiscreteScheduler",
        "IPNDMScheduler",
        "KarrasVeScheduler",
@@ -270,6 +272,7 @@
        "LDMTextToImagePipeline",
        "LEditsPPPipelineStableDiffusion",
        "LEditsPPPipelineStableDiffusionXL",
+       "LuminaText2ImgPipeline",
        "MarigoldDepthPipeline",
        "MarigoldNormalsPipeline",
        "MusicLDMPipeline",
@@ -509,6 +512,7 @@
        HunyuanDiT2DMultiControlNetModel,
        I2VGenXLUNet,
        Kandinsky3UNet,
+       LuminaNextDiT2DModel,
        ModelMixin,
        MotionAdapter,
        MultiAdapter,
@@ -580,6 +584,7 @@
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        FlowMatchEulerDiscreteScheduler,
+       FlowMatchHeunDiscreteScheduler,
        HeunDiscreteScheduler,
        IPNDMScheduler,
        KarrasVeScheduler,
@@ -669,6 +674,7 @@
        LDMTextToImagePipeline,
        LEditsPPPipelineStableDiffusion,
        LEditsPPPipelineStableDiffusionXL,
+       LuminaText2ImgPipeline,
        MarigoldDepthPipeline,
        MarigoldNormalsPipeline,
        MusicLDMPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index f3fda596aa71..4d80aec6935b 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -41,6 +41,7 @@
    _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
    _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
    _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
+   _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
    _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
    _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
    _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
@@ -85,6 +86,7 @@
        DiTTransformer2DModel,
        DualTransformer2DModel,
        HunyuanDiT2DModel,
+       LuminaNextDiT2DModel,
        PixArtTransformer2DModel,
        PriorTransformer,
        SD3Transformer2DModel,
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 2a81f357d48b..d35f875f144c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -19,7 +19,7 @@
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
@@ -527,6 +527,56 @@ def forward(
        return hidden_states

+class LuminaFeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`):
+            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+            hidden representations.
+        inner_dim (`int`): The intermediate dimension of the feed-forward layer, rescaled internally before use.
+        multiple_of (`int`, *optional*): Value to ensure the hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden dimension. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        inner_dim: int,
+        multiple_of: Optional[int] = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+    ):
+        super().__init__()
+        inner_dim = int(2 * inner_dim / 3)
+        # custom hidden_size factor multiplier
+        if ffn_dim_multiplier is not None:
+            inner_dim = int(ffn_dim_multiplier * inner_dim)
+        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
+
+        self.linear_1 = nn.Linear(
+            dim,
+            inner_dim,
+            bias=False,
+        )
+        self.linear_2 = nn.Linear(
+            inner_dim,
+            dim,
+            bias=False,
+        )
+        self.linear_3 = nn.Linear(
+            dim,
+            inner_dim,
+            bias=False,
+        )
+        self.silu = FP32SiLU()
+
+    def forward(self, x):
+        return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
+
+
 @maybe_allow_in_graph
 class TemporalBasicTransformerBlock(nn.Module):
    r"""
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 9d495695e330..ac773ba48103 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -94,6 +94,7 @@ def __init__(
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
+       kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
@@ -118,6 +119,7 @@
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+       self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
@@ -168,6 +170,10 @@
        elif qk_norm == "layer_norm":
            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+       elif qk_norm == "layer_norm_across_heads":
+           # Lumina applies qk norm across all heads
+           self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+           self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
        else:
            raise ValueError(f"unknown qk_norm: {qk_norm}.
Should be None, 'layer_norm' or 'layer_norm_across_heads'")
@@ -198,15 +204,15 @@ def __init__(
        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
-           self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
-           self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+           self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+           self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
-           self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
-           self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+           self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
+           self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
            if self.context_pre_only is not None:
                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
@@ -1594,6 +1600,102 @@ def __call__(
        return hidden_states

+class LuminaAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the LuminaNextDiT model. It applies normalization layers and rotary embeddings to the query and key
+    vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        query_rotary_emb: Optional[torch.Tensor] = None,
+        key_rotary_emb: Optional[torch.Tensor] = None,
+        base_sequence_length: Optional[int] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        # Get Query-Key-Value Pair
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query_dim = query.shape[-1]
+        inner_dim = key.shape[-1]
+        head_dim = query_dim // attn.heads
+        dtype = query.dtype
+
+        # Get key-value heads
+        kv_heads = inner_dim // head_dim
+
+        # Apply Query-Key Norm if needed
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+
+        key = key.view(batch_size, -1, kv_heads, head_dim)
+        value = value.view(batch_size, -1, kv_heads, head_dim)
+
+        # Apply RoPE if needed
+        if query_rotary_emb is not None:
+            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
+        if key_rotary_emb is not None:
+            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
+
+        query, key = query.to(dtype), key.to(dtype)
+
+        # Apply proportional attention if true
+        if key_rotary_emb is None:
+            softmax_scale = None
+        else:
+            if base_sequence_length is not None:
+                softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+            else:
+                softmax_scale = attn.scale
+
+        # perform Grouped-query Attention (GQA)
+        n_rep = attn.heads // kv_heads
+        if n_rep >= 1:
+            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+        # scaled_dot_product_attention expects attention_mask
shape to be
+        # (batch, heads, source_length, target_length)
+        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # note: passing `scale` to scaled_dot_product_attention requires PyTorch >= 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, scale=softmax_scale
+        )
+        hidden_states = hidden_states.transpose(1, 2).to(dtype)
+
+        return hidden_states
+
+
 class FusedAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index cb6cb065dd32..8bc30f7cabcf 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -230,6 +230,52 @@ def forward(self, latent):
        return (latent + pos_embed).to(latent.dtype)

+class LuminaPatchEmbed(nn.Module):
+    """2D Image to Patch Embedding with support for Lumina-T2X"""
+
+    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
+        super().__init__()
+        self.patch_size = patch_size
+        self.proj = nn.Linear(
+            in_features=patch_size * patch_size * in_channels,
+            out_features=embed_dim,
+            bias=bias,
+        )
+
+    def forward(self, x, freqs_cis):
+        """
+        Patchifies and embeds the input tensor(s).
+
+        Args:
+            x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
+            freqs_cis (torch.Tensor): The precomputed frequency tensor for the rotary embeddings.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
+            and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
+            frequency tensor(s).
+        """
+        freqs_cis = freqs_cis.to(x[0].device)
+        patch_height = patch_width = self.patch_size
+        batch_size, channel, height, width = x.size()
+        height_tokens, width_tokens = height // patch_height, width // patch_width
+
+        x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
+            0, 2, 4, 1, 3, 5
+        )
+        x = x.flatten(3)
+        x = self.proj(x)
+        x = x.flatten(1, 2)
+
+        mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
+
+        return (
+            x,
+            mask,
+            [(height, width)] * batch_size,
+            freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
+        )
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
    """
    RoPE for image tokens with 2d structure.
@@ -274,7 +320,25 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): return emb -def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): +def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): + assert embed_dim % 4 == 0 + + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (H, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (W, D/4) + emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) + emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) + + emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0 +): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -289,13 +353,17 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float Scaling factor for frequency computation. Defaults to 10000.0. use_real (`bool`, *optional*): If True, return real part and imaginary part separately. Otherwise, return complex numbers. - + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ if isinstance(pos, int): pos = np.arange(pos) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + theta = theta * ntk_factor + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] if use_real: @@ -310,6 +378,7 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float def apply_rotary_emb( x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -325,16 +394,23 @@ def apply_rotary_emb( Returns: Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
""" - cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] - cos, sin = cos.to(x.device), sin.to(x.device) + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) class TimestepEmbedding(nn.Module): @@ -778,6 +854,40 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde return conditioning +class LuminaCombinedTimestepCaptionEmbedding(nn.Module): + def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256): + super().__init__() + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) + + self.caption_embedder = nn.Sequential( + nn.LayerNorm(cross_attention_dim), + nn.Linear( + cross_attention_dim, + hidden_size, + bias=True, + ), + ) + + def forward(self, timestep, caption_feat, caption_mask): + # timestep embedding: + time_freq = self.time_proj(timestep) + time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + + # caption condition embedding: + caption_mask_float = caption_mask.float().unsqueeze(-1) + caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1) + caption_feats_pool = caption_feats_pool.to(caption_feat) + caption_embed = self.caption_embedder(caption_feats_pool) + + conditioning = time_embed + caption_embed + + return conditioning + + class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index a1a7ce91d754..16d76faad0c5 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -22,7 +22,10 @@ from ..utils import is_torch_version from .activations import get_activation -from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings +from .embeddings import ( + CombinedTimestepLabelEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings, +) class AdaLayerNorm(nn.Module): @@ -84,6 +87,37 @@ def forward( return x, gate_msa, shift_mlp, scale_mlp, gate_mlp +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+        norm_eps (`float`): The epsilon used by the RMS normalization layer.
+        norm_elementwise_affine (`bool`): Whether to use learnable elementwise affine parameters in the RMS norm.
+    """
+
+    def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(
+            min(embedding_dim, 1024),
+            4 * embedding_dim,
+            bias=True,
+        )
+        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None])
+
+        return x, gate_msa, scale_mlp, gate_mlp
+
+
 class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).
@@ -188,6 +222,54 @@ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torc
        return x

+class LuminaLayerNormContinuous(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+        out_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        # AdaLN
+        self.silu = nn.SiLU()
+        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+        # linear_2 (set to None when `out_dim` is not given so the check in `forward` cannot fail)
+        self.linear_2 = None
+        if out_dim is not None:
+            self.linear_2 = nn.Linear(
+                embedding_dim,
+                out_dim,
+                bias=bias,
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        conditioning_embedding: torch.Tensor,
+    ) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
+        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+        scale = emb
+        x = self.norm(x) * (1 + scale)[:, None, :]
+
+        if self.linear_2 is not None:
+            x = self.linear_2(x)
+
+        return x
+
+
 if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
 else:
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 04bd21b70737..533fb621269e 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -5,6 +5,7 @@
    from .dit_transformer_2d import DiTTransformer2DModel
    from .dual_transformer_2d import DualTransformer2DModel
    from .hunyuan_transformer_2d import HunyuanDiT2DModel
+   from .lumina_nextdit2d import LuminaNextDiT2DModel
    from .pixart_transformer_2d import PixArtTransformer2DModel
    from .prior_transformer import PriorTransformer
    from .t5_film_transformer import T5FilmDecoder
diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
new file mode 100644
index 000000000000..d4f5b4658542
--- /dev/null
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -0,0 +1,340 @@
+# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ..attention import LuminaFeedForward
+from ..attention_processor import Attention, LuminaAttnProcessor2_0
+from ..embeddings import (
+    LuminaCombinedTimestepCaptionEmbedding,
+    LuminaPatchEmbed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class LuminaNextDiTBlock(nn.Module):
+    """
+    A LuminaNextDiTBlock for LuminaNextDiT2DModel.
+
+    Parameters:
+        dim (`int`): Embedding dimension of the input features.
+        num_attention_heads (`int`): Number of attention heads.
+        num_kv_heads (`int`):
+            Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
+        multiple_of (`int`): Value to ensure the feed-forward hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (`float`): Custom multiplier for the feed-forward layer dimension.
+        norm_eps (`float`): The epsilon value for the normalization layers.
+        qk_norm (`bool`): Whether to apply normalization to the query and key.
+        cross_attention_dim (`int`): Embedding dimension of the text prompt hidden_states used for cross-attention.
+        norm_elementwise_affine (`bool`, *optional*, defaults to True):
+            Whether to use learnable elementwise affine parameters in the normalization layers.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        num_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        qk_norm: bool,
+        cross_attention_dim: int,
+        norm_elementwise_affine: bool = True,
+    ) -> None:
+        super().__init__()
+        self.head_dim = dim // num_attention_heads
+
+        self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
+
+        # Self-attention
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=dim // num_attention_heads,
+            qk_norm="layer_norm_across_heads" if qk_norm else None,
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=LuminaAttnProcessor2_0(),
+        )
+        self.attn1.to_out = nn.Identity()
+
+        # Cross-attention
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            dim_head=dim // num_attention_heads,
+            qk_norm="layer_norm_across_heads" if qk_norm else None,
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=LuminaAttnProcessor2_0(),
+        )
+
+        self.feed_forward = LuminaFeedForward(
+            dim=dim,
+            inner_dim=4 * dim,
+            multiple_of=multiple_of,
+            ffn_dim_multiplier=ffn_dim_multiplier,
+        )
+
+        self.norm1 = LuminaRMSNormZero(
+            embedding_dim=dim,
+            norm_eps=norm_eps,
+            norm_elementwise_affine=norm_elementwise_affine,
+        )
+        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+        self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+        self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        temb: torch.Tensor,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform a forward pass through the LuminaNextDiTBlock.
+
+        Parameters:
+            hidden_states (`torch.Tensor`): The input hidden_states for the LuminaNextDiTBlock.
+            attention_mask (`torch.Tensor`): The attention mask corresponding to hidden_states.
+            image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
+            encoder_hidden_states (`torch.Tensor`): The text prompt hidden_states produced by the Gemma encoder.
+            encoder_mask (`torch.Tensor`): The attention mask for the text prompt hidden_states.
+            temb (`torch.Tensor`): Timestep embedding combined with the text prompt embedding.
+            cross_attention_kwargs (`Dict[str, Any]`): kwargs for the cross-attention layers.
+        """
+        residual = hidden_states
+
+        # Self-attention
+        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+        self_attn_output = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_hidden_states,
+            attention_mask=attention_mask,
+            query_rotary_emb=image_rotary_emb,
+            key_rotary_emb=image_rotary_emb,
+            **cross_attention_kwargs,
+        )
+
+        # Cross-attention
+        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+        cross_attn_output = self.attn2(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=encoder_mask,
+            query_rotary_emb=image_rotary_emb,
+            key_rotary_emb=None,
+            **cross_attention_kwargs,
+        )
+        cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
+        mixed_attn_output = self_attn_output + cross_attn_output
+        mixed_attn_output = mixed_attn_output.flatten(-2)
+        # linear proj
+        hidden_states = self.attn2.to_out[0](mixed_attn_output)
+
+        hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
+
+        mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+
+        return hidden_states
+
+
+class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
+    """
+    LuminaNextDiT: Diffusion model with a Transformer backbone.
+
+    Inherits from ModelMixin and ConfigMixin so it can be loaded, configured, and sampled with the diffusers
+    pipelines.
+
+    Parameters:
+        sample_size (`int`): The width of the latent images. This is fixed during training since
+            it is used to learn a number of position embeddings.
+        patch_size (`int`, *optional*, defaults to 2):
+            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+        in_channels (`int`, *optional*, defaults to 4):
+            The number of input channels for the model. Typically, this matches the number of channels in the input
+            images.
+        hidden_size (`int`, *optional*, defaults to 2304):
+            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+            hidden representations.
+        num_layers (`int`, *optional*, defaults to 32):
+            The number of layers in the model. This defines the depth of the neural network.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            The number of attention heads in each attention layer. This parameter specifies how many separate attention
+            mechanisms are used.
+        num_kv_heads (`int`, *optional*):
+            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+            If None, it defaults to num_attention_heads.
+        multiple_of (`int`, *optional*, defaults to 256):
+            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
+            configurations.
+        ffn_dim_multiplier (`float`, *optional*):
+            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
+            the model configuration.
+        norm_eps (`float`, *optional*, defaults to 1e-5):
+            A small value added to the denominator for numerical stability in normalization layers.
+        learn_sigma (`bool`, *optional*, defaults to True):
+            Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
+            predictions.
+        qk_norm (`bool`, *optional*, defaults to True):
+            Indicates if the queries and keys in the attention mechanism should be normalized.
+        cross_attention_dim (`int`, *optional*, defaults to 2048):
+            The dimensionality of the text embeddings. This parameter defines the size of the text representations used
+            in the model.
+        scaling_factor (`float`, *optional*, defaults to 1.0):
+            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
+            overall scale of the model's operations.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 128,
+        patch_size: Optional[int] = 2,
+        in_channels: Optional[int] = 4,
+        hidden_size: Optional[int] = 2304,
+        num_layers: Optional[int] = 32,
+        num_attention_heads: Optional[int] = 32,
+        num_kv_heads: Optional[int] = None,
+        multiple_of: Optional[int] = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+        norm_eps: Optional[float] = 1e-5,
+        learn_sigma: Optional[bool] = True,
+        qk_norm: Optional[bool] = True,
+        cross_attention_dim: Optional[int] = 2048,
+        scaling_factor: Optional[float] = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_size = sample_size
+        self.patch_size = patch_size
+        self.in_channels = in_channels
+        self.out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling_factor = scaling_factor
+
+        self.patch_embedder = LuminaPatchEmbed(
+            patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
+        )
+
+        self.pad_token = nn.Parameter(torch.empty(hidden_size))
+
+        self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
+            hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                LuminaNextDiTBlock(
+                    hidden_size,
+                    num_attention_heads,
+                    num_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    cross_attention_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_out = LuminaLayerNormContinuous(
+            embedding_dim=hidden_size,
+            conditioning_embedding_dim=min(hidden_size, 1024),
+            elementwise_affine=False,
+            eps=1e-6,
+            bias=True,
+            out_dim=patch_size * patch_size * self.out_channels,
+        )
+
+        assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict=True,
+    ) -> torch.Tensor:
+        """
+        Forward pass of LuminaNextDiT.
+
+        Parameters:
+            hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
+            timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
+            encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, L, D).
+            encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
+ """ + hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb) + image_rotary_emb = image_rotary_emb.to(hidden_states.device) + + temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask) + + encoder_mask = encoder_mask.bool() + for layer in self.layers: + hidden_states = layer( + hidden_states, + mask, + image_rotary_emb, + encoder_hidden_states, + encoder_mask, + temb=temb, + cross_attention_kwargs=cross_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + + # unpatchify + height_tokens = width_tokens = self.patch_size + height, width = img_size[0] + batch_size = hidden_states.size(0) + sequence_length = (height // height_tokens) * (width // width_tokens) + hidden_states = hidden_states[:, :sequence_length].view( + batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels + ) + output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 4f135c9e43aa..aee2c609281f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -209,6 +209,7 @@ "LEditsPPPipelineStableDiffusionXL", ] ) + _import_structure["lumina"] = ["LuminaText2ImgPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -486,6 +487,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + from .lumina import LuminaText2ImgPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldNormalsPipeline, diff --git a/src/diffusers/pipelines/lumina/__init__.py b/src/diffusers/pipelines/lumina/__init__.py new file mode 100644 index 000000000000..ca1396359721 --- /dev/null +++ b/src/diffusers/pipelines/lumina/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lumina import LuminaText2ImgPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py new file mode 100644 index 000000000000..41c88cbd2b1e --- /dev/null +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -0,0 +1,897 @@ +# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import math
+import re
+import urllib.parse as ul
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL
+from ...models.embeddings import get_2d_rotary_pos_embed_lumina
+from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import LuminaText2ImgPipeline
+
+ >>> pipe = LuminaText2ImgPipeline.from_pretrained(
+ ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> # Enable memory optimizations (CPU offloading expects the model to start on the CPU).
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
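+
+ Example (an illustrative sketch, not part of the documented API; it assumes an instantiated
+ pipeline `pipe` whose scheduler's `set_timesteps` accepts custom `sigmas`):
+
+ ```py
+ >>> # `num_inference_steps` and `timesteps` must be None when passing `sigmas`
+ >>> timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, sigmas=[1.0, 0.75, 0.5, 0.25])
+ ```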
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LuminaText2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Lumina-T2I. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`AutoModel`]): + Frozen text-encoder. Lumina-T2I uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. + tokenizer (`AutoModel`): + Tokenizer of class + [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
+ """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + transformer: LuminaNextDiT2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: AutoModel, + tokenizer: AutoTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.max_sequence_length = 256 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.default_image_size = self.default_sample_size * self.vae_scale_factor + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clean_caption: Optional[bool] = False, + max_length: Optional[int] = None, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + pad_to_multiple_of=8, + max_length=self.max_sequence_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {self.max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-2] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: 
Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ Lumina-T2I, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier-free guidance or not.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Lumina-T2I, it should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
+ """
+ if device is None:
+ device = self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ clean_caption=clean_caption,
+ )
+
+ # Get negative embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+ # Normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ # Padding negative prompt to the same length as the prompt
+ prompt_max_length = prompt_embeds.shape[1]
+ negative_text_inputs = self.tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=prompt_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ negative_text_input_ids = negative_text_inputs.input_ids.to(device)
+ negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device)
+ # Get the negative prompt embeddings
+ negative_prompt_embeds = self.text_encoder(
+ negative_text_input_ids,
+ attention_mask=negative_prompt_attention_mask,
+ output_hidden_states=True,
+ )
+
+ negative_dtype = self.text_encoder.dtype
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+ _, seq_len, _ = negative_prompt_embeds.shape
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device)
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
+ batch_size * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("<person>", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @<nickname>
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "",
caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # &quot;
+ caption = re.sub(r"&quot;?", "", caption)
+ # &amp
+ caption = re.sub(r"&amp", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption = caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+ num_inference_steps: int = 30,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ negative_prompt: Union[str, List[str]] = None,
+ sigmas: List[float] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = True,
+ max_sequence_length: int = 256,
+ scaling_watershed: Optional[float] = 1.0,
+ proportional_attn: Optional[bool] = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ max_sequence_length (`int`, *optional*, defaults to 256):
+ Maximum sequence length to use with the `prompt`.
+ scaling_watershed (`float`, *optional*, defaults to 1.0):
+ Watershed on the (reversed) timestep that switches the rotary-embedding scaling between linear and NTK
+ mode when generating at resolutions other than the default.
+ proportional_attn (`bool`, *optional*, defaults to `True`):
+ Whether to pass `base_sequence_length` to the attention layers so that attention is scaled
+ proportionally to the sequence length.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + cross_attention_kwargs = {} + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if proportional_attn: + cross_attention_kwargs["base_sequence_length"] = (self.default_image_size // 16) ** 2 + + scaling_factor = math.sqrt(width * height / self.default_image_size**2) + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor(
+ [current_timestep],
+ dtype=dtype,
+ device=latent_model_input.device,
+ )
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
+ current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps
+
+ # prepare image_rotary_emb for positional encoding
+ # dynamic scaling_factor for different resolutions.
+ # NOTE: For the `Time-aware` denoising mechanism from Lumina-Next
+ # https://arxiv.org/abs/2406.18583, Sec 2.3
+ # NOTE: We should compute different image_rotary_emb with different timestep.
+ if current_timestep[0] < scaling_watershed:
+ linear_factor = scaling_factor
+ ntk_factor = 1.0
+ else:
+ linear_factor = 1.0
+ ntk_factor = scaling_factor
+ image_rotary_emb = get_2d_rotary_pos_embed_lumina(
+ self.transformer.head_dim,
+ 384,
+ 384,
+ linear_factor=linear_factor,
+ ntk_factor=ntk_factor,
+ )
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=current_timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_mask=prompt_attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+
+ # perform classifier-free guidance
+ # NOTE: For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ if do_classifier_free_guidance:
+ noise_pred_eps, noise_pred_rest = noise_pred[:, :3], noise_pred[:, 3:]
+ noise_pred_cond_eps, noise_pred_uncond_eps = torch.split(
+ noise_pred_eps, len(noise_pred_eps) // 2, dim=0
+ )
+ noise_pred_half = noise_pred_uncond_eps + guidance_scale * (
+ noise_pred_cond_eps - noise_pred_uncond_eps
+ )
+ noise_pred_eps = torch.cat([noise_pred_half, noise_pred_half], dim=0)
+
+ noise_pred = torch.cat([noise_pred_eps, noise_pred_rest], dim=1)
+ noise_pred, _ = noise_pred.chunk(2, dim=0)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ noise_pred = -noise_pred
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg.
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + progress_bar.update() + + if not output_type == "latent": + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 33e48fc3b0dc..dfee479bfa96 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -57,6 +57,7 @@ _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"] _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"] _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] + _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] @@ -153,6 +154,7 @@ from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py new file mode 100644 index 000000000000..d9a3ca2d4b0a --- /dev/null +++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py @@ -0,0 +1,321 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. 
`prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Heun's second-order scheduler for flow-matching models.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ shift (`float`, defaults to 1.0):
+ The shift value for the timestep schedule.
+ """
+
+ _compatibles = []
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ ):
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+ sigmas = timesteps / num_train_timesteps
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+ self.timesteps = sigmas * num_train_timesteps
+
+ self._step_index = None
+ self._begin_index = None
+
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def scale_noise(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ Forward process in flow-matching
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`float` or `torch.FloatTensor`):
+ The current timestep in the diffusion chain.
+ noise (`torch.FloatTensor`, *optional*):
+ The noise to mix into the sample.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ sigma = self.sigmas[self.step_index]
+ sample = sigma * noise + (1.0 - sigma) * sample
+
+ return sample
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
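+
+ A small illustrative sketch (the step count below is arbitrary):
+
+ ```py
+ >>> scheduler = FlowMatchHeunDiscreteScheduler(num_train_timesteps=1000, shift=1.0)
+ >>> scheduler.set_timesteps(num_inference_steps=30, device="cpu")
+ >>> # Heun is a two-stage method, so every interior timestep appears twice:
+ >>> len(scheduler.timesteps) # 1 + 2 * (30 - 1)
+ 59
+ ```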
+ """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + self.timesteps = timesteps.to(device=device) + + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + @property + def state_in_first_order(self): + return self.dt is None + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `HeunDiscreteScheduler.step()` is not supported. 
Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[self.step_index - 1] + sigma_next = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + if self.state_in_first_order: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma + # 2. convert to an ODE derivative for 1st order + derivative = (sample - denoised) / sigma_hat + # 3. Delta timestep + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma_next + # 2. 2nd order / Heun's method + derivative = (sample - denoised) / sigma_next + derivative = 0.5 * (self.prev_derivative + derivative) + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 354ce7e0ba34..cd55c075f8bf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -197,6 +197,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LuminaNextDiT2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] @@ -1095,6 +1110,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FlowMatchHeunDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HeunDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py 
b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a1bb667128df..482ac39de919 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -722,6 +722,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LuminaText2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class MarigoldDepthPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/lumina/__init__.py b/tests/pipelines/lumina/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py new file mode 100644 index 000000000000..a53758ce2808 --- /dev/null +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -0,0 +1,179 @@ +import gc +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline +from diffusers.utils.testing_utils import ( + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = LuminaText2ImgPipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "negative_prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LuminaNextDiT2DModel( + sample_size=16, + patch_size=2, + in_channels=4, + hidden_size=24, + num_layers=2, + num_attention_heads=3, + num_kv_heads=1, + multiple_of=16, + ffn_dim_multiplier=None, + norm_eps=1e-5, + learn_sigma=True, + qk_norm=True, + cross_attention_dim=32, + scaling_factor=1.0, + ) + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = FlowMatchEulerDiscreteScheduler() + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + + torch.manual_seed(0) + config = GemmaConfig( + head_dim=4, + hidden_size=32, + intermediate_size=37, + num_attention_heads=4, + num_hidden_layers=2, + num_key_value_heads=4, + ) + text_encoder = GemmaForCausalLM(config) + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + } + return inputs + + def test_lumina_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + 
output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + do_classifier_free_guidance = inputs["guidance_scale"] > 1 + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = pipe.encode_prompt( + prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + device=torch_device, + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + +@slow +@require_torch_gpu +class LuminaText2ImgPipelineSlowTests(unittest.TestCase): + pipeline_class = LuminaText2ImgPipeline + repo_id = "Alpha-VLLM/Lumina-Next-SFT-diffusers" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + return { + "prompt": "A photo of a cat", + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "generator": generator, + } + + def test_lumina_inference(self): + pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10] + expected_slice = np.array( + [ + [0.17773438, 0.18554688, 0.22070312], + [0.046875, 0.06640625, 0.10351562], + [0.0, 0.0, 0.02148438], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + dtype=np.float32, + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4
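+
+
+# NOTE: `numpy_cosine_similarity_distance` is a diffusers testing utility; it is assumed
+# here to compute one minus the mean cosine similarity of two flattened arrays, roughly
+# along the lines of this sketch:
+#
+# import numpy as np
+#
+# def numpy_cosine_similarity_distance(a, b):
+#     similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+#     return 1.0 - similarity.mean()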