Skip to content

Commit

Permalink
[PixArt-Alpha] Introduce resolution binning (huggingface#5739)
Browse files Browse the repository at this point in the history
* feat: add resolution binning

Co-authored-by: lawrence-cj <[email protected]>

* rename

* debug

* add :test

* remove unused variable

* set resolution_binning to False.

---------

Co-authored-by: lawrence-cj <[email protected]>
  • Loading branch information
sayakpaul and lawrence-cj authored Nov 14, 2023
1 parent 5b231aa commit ed759f0
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 3 deletions.
81 changes: 80 additions & 1 deletion src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import T5EncoderModel, T5Tokenizer

from ...image_processor import VaeImageProcessor
Expand All @@ -43,7 +44,6 @@
if is_ftfy_available():
import ftfy


EXAMPLE_DOC_STRING = """
Examples:
```py
Expand All @@ -60,6 +60,42 @@
```
"""

ASPECT_RATIO_1024_BIN = {
"0.25": [512.0, 2048.0],
"0.28": [512.0, 1856.0],
"0.32": [576.0, 1792.0],
"0.33": [576.0, 1728.0],
"0.35": [576.0, 1664.0],
"0.4": [640.0, 1600.0],
"0.42": [640.0, 1536.0],
"0.48": [704.0, 1472.0],
"0.5": [704.0, 1408.0],
"0.52": [704.0, 1344.0],
"0.57": [768.0, 1344.0],
"0.6": [768.0, 1280.0],
"0.68": [832.0, 1216.0],
"0.72": [832.0, 1152.0],
"0.78": [896.0, 1152.0],
"0.82": [896.0, 1088.0],
"0.88": [960.0, 1088.0],
"0.94": [960.0, 1024.0],
"1.0": [1024.0, 1024.0],
"1.07": [1024.0, 960.0],
"1.13": [1088.0, 960.0],
"1.21": [1088.0, 896.0],
"1.29": [1152.0, 896.0],
"1.38": [1152.0, 832.0],
"1.46": [1216.0, 832.0],
"1.67": [1280.0, 768.0],
"1.75": [1344.0, 768.0],
"2.0": [1408.0, 704.0],
"2.09": [1472.0, 704.0],
"2.4": [1536.0, 640.0],
"2.5": [1600.0, 640.0],
"3.0": [1728.0, 576.0],
"4.0": [2048.0, 512.0],
}


class PixArtAlphaPipeline(DiffusionPipeline):
r"""
Expand Down Expand Up @@ -495,6 +531,38 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents

@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
"""Returns binned height and width."""
ar = float(height / width)
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
default_hw = ratios[closest_ratio]
return int(default_hw[0]), int(default_hw[1])

@staticmethod
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
orig_height, orig_width = samples.shape[2], samples.shape[3]

# Check if resizing is needed
if orig_height != new_height or orig_width != new_width:
ratio = max(new_height / orig_height, new_width / orig_width)
resized_width = int(orig_width * ratio)
resized_height = int(orig_height * ratio)

# Resize
samples = F.interpolate(
samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)

# Center Crop
start_x = (resized_width - new_width) // 2
end_x = start_x + new_width
start_y = (resized_height - new_height) // 2
end_y = start_y + new_height
samples = samples[:, :, start_y:end_y, start_x:end_x]

return samples

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand All @@ -518,6 +586,7 @@ def __call__(
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
use_resolution_binning: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -580,6 +649,10 @@ def __call__(
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
use_resolution_binning:
(`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
they are resized back to the requested resolution. Useful for generating non-square images.
Examples:
Expand All @@ -591,6 +664,10 @@ def __call__(
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
orig_height, orig_width = height, width
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)

self.check_inputs(
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
)
Expand Down Expand Up @@ -709,6 +786,8 @@ def __call__(

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents

Expand Down
9 changes: 7 additions & 2 deletions tests/pipelines/pixart/test_pixart.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "numpy",
"use_resolution_binning": False,
"output_type": "np",
}
return inputs

Expand Down Expand Up @@ -120,6 +121,7 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}

# set all optional components to None
Expand Down Expand Up @@ -154,6 +156,7 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}

output_loaded = pipe_loaded(**inputs)[0]
Expand Down Expand Up @@ -189,8 +192,8 @@ def test_inference_non_square_images(self):
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs, height=32, width=48).images
image_slice = image[0, -3:, -3:, -1]

self.assertEqual(image.shape, (1, 32, 48, 3))

expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
Expand Down Expand Up @@ -219,6 +222,7 @@ def test_inference_with_embeddings_and_multiple_images(self):
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"num_images_per_prompt": 2,
"use_resolution_binning": False,
}

# set all optional components to None
Expand Down Expand Up @@ -254,6 +258,7 @@ def test_inference_with_embeddings_and_multiple_images(self):
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"num_images_per_prompt": 2,
"use_resolution_binning": False,
}

output_loaded = pipe_loaded(**inputs)[0]
Expand Down

0 comments on commit ed759f0

Please sign in to comment.