diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 147e2b76e6c6..c3f667ba16be 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -19,6 +19,7 @@
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...image_processor import VaeImageProcessor
@@ -43,7 +44,6 @@
 if is_ftfy_available():
     import ftfy
 
-
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -60,6 +60,42 @@
         ```
 """
 
+ASPECT_RATIO_1024_BIN = {
+    "0.25": [512.0, 2048.0],
+    "0.28": [512.0, 1856.0],
+    "0.32": [576.0, 1792.0],
+    "0.33": [576.0, 1728.0],
+    "0.35": [576.0, 1664.0],
+    "0.4": [640.0, 1600.0],
+    "0.42": [640.0, 1536.0],
+    "0.48": [704.0, 1472.0],
+    "0.5": [704.0, 1408.0],
+    "0.52": [704.0, 1344.0],
+    "0.57": [768.0, 1344.0],
+    "0.6": [768.0, 1280.0],
+    "0.68": [832.0, 1216.0],
+    "0.72": [832.0, 1152.0],
+    "0.78": [896.0, 1152.0],
+    "0.82": [896.0, 1088.0],
+    "0.88": [960.0, 1088.0],
+    "0.94": [960.0, 1024.0],
+    "1.0": [1024.0, 1024.0],
+    "1.07": [1024.0, 960.0],
+    "1.13": [1088.0, 960.0],
+    "1.21": [1088.0, 896.0],
+    "1.29": [1152.0, 896.0],
+    "1.38": [1152.0, 832.0],
+    "1.46": [1216.0, 832.0],
+    "1.67": [1280.0, 768.0],
+    "1.75": [1344.0, 768.0],
+    "2.0": [1408.0, 704.0],
+    "2.09": [1472.0, 704.0],
+    "2.4": [1536.0, 640.0],
+    "2.5": [1600.0, 640.0],
+    "3.0": [1728.0, 576.0],
+    "4.0": [2048.0, 512.0],
+}
+
 
 class PixArtAlphaPipeline(DiffusionPipeline):
     r"""
@@ -495,6 +531,38 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
+    @staticmethod
+    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+        """Returns binned height and width."""
+        ar = float(height / width)
+        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+        default_hw = ratios[closest_ratio]
+        return int(default_hw[0]), int(default_hw[1])
+
+    @staticmethod
+    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        orig_height, orig_width = samples.shape[2], samples.shape[3]
+
+        # Check if resizing is needed
+        if orig_height != new_height or orig_width != new_width:
+            ratio = max(new_height / orig_height, new_width / orig_width)
+            resized_width = int(orig_width * ratio)
+            resized_height = int(orig_height * ratio)
+
+            # Resize
+            samples = F.interpolate(
+                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+            )
+
+            # Center Crop
+            start_x = (resized_width - new_width) // 2
+            end_x = start_x + new_width
+            start_y = (resized_height - new_height) // 2
+            end_y = start_y + new_height
+            samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+        return samples
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -518,6 +586,7 @@ def __call__(
         callback_steps: int = 1,
         clean_caption: bool = True,
         mask_feature: bool = True,
+        use_resolution_binning: bool = True,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -580,6 +649,10 @@ def __call__(
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
                 prompt.
             mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+            use_resolution_binning (`bool` defaults to `True`):
+                If set to `True`, the requested height and width are first mapped to the closest resolutions using
+                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+                the requested resolution. Useful for generating non-square images.
 
         Examples:
 
@@ -591,6 +664,10 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
+        if use_resolution_binning:
+            orig_height, orig_width = height, width
+            height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
+
         self.check_inputs(
             prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
         )
@@ -709,6 +786,8 @@ def __call__(
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            if use_resolution_binning:
+                image = self.resize_and_crop_tensor(image, orig_width, orig_height)
         else:
             image = latents
 
diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py
index a04f4e1a8804..1fb2560b29b6 100644
--- a/tests/pipelines/pixart/test_pixart.py
+++ b/tests/pipelines/pixart/test_pixart.py
@@ -89,7 +89,8 @@ def get_dummy_inputs(self, device, seed=0):
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 5.0,
-            "output_type": "numpy",
+            "use_resolution_binning": False,
+            "output_type": "np",
         }
         return inputs
 
@@ -120,6 +121,7 @@ def test_save_load_optional_components(self):
             "generator": generator,
             "num_inference_steps": num_inference_steps,
             "output_type": output_type,
+            "use_resolution_binning": False,
         }
 
         # set all optional components to None
@@ -154,6 +156,7 @@ def test_save_load_optional_components(self):
             "generator": generator,
             "num_inference_steps": num_inference_steps,
             "output_type": output_type,
+            "use_resolution_binning": False,
         }
 
         output_loaded = pipe_loaded(**inputs)[0]
@@ -189,8 +192,8 @@ def test_inference_non_square_images(self):
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs, height=32, width=48).images
         image_slice = image[0, -3:, -3:, -1]
-
         self.assertEqual(image.shape, (1, 32, 48, 3))
+
         expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
@@ -219,6 +222,7 @@ def test_inference_with_embeddings_and_multiple_images(self):
             "num_inference_steps": num_inference_steps,
             "output_type": output_type,
             "num_images_per_prompt": 2,
+            "use_resolution_binning": False,
         }
 
         # set all optional components to None
@@ -254,6 +258,7 @@ def test_inference_with_embeddings_and_multiple_images(self):
             "num_inference_steps": num_inference_steps,
             "output_type": output_type,
             "num_images_per_prompt": 2,
+            "use_resolution_binning": False,
         }
 
         output_loaded = pipe_loaded(**inputs)[0]
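A minimal sketch of the binning round trip this patch introduces, assuming the helpers land exactly as in the diff; the 1000x500 request, the dummy tensor, and the checkpoint id in the trailing comment are illustrative, not taken from the patch:

```py
import torch

from diffusers import PixArtAlphaPipeline
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN

# A 1000x500 request has aspect ratio 2.0, so it snaps to the "2.0" bin: 1408x704.
height, width = PixArtAlphaPipeline.classify_height_width_bin(1000, 500, ratios=ASPECT_RATIO_1024_BIN)
print(height, width)  # 1408 704

# Stand-in for a decoded (N, C, H, W) image batch: the helper upscales until both
# target sides are covered, then center-crops to exactly the requested size.
samples = torch.randn(1, 3, 64, 64)
cropped = PixArtAlphaPipeline.resize_and_crop_tensor(samples, new_width=48, new_height=32)
print(cropped.shape)  # torch.Size([1, 3, 32, 48])

# End to end, binning is on by default, so a call like
#   pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS")
#   image = pipe("an astronaut", height=1000, width=500).images[0]
# denoises at the 1408x704 bin and still returns a 1000x500 image.
```

Disabling the flag (`use_resolution_binning=False`), as the updated tests do, skips both the bin lookup and the resize-back, which keeps the tiny test resolutions untouched.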