[refactor]: add type checking to sample image and video functions #325

Open · wants to merge 1 commit into main

22 changes: 11 additions & 11 deletions opensora/models/diffusion/latte/modeling_latte.py
@@ -5,7 +5,7 @@

from dataclasses import dataclass
from einops import rearrange, repeat
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
from diffusers.models import Transformer2DModel
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings
@@ -86,7 +86,7 @@ def __init__(
rope_scaling_type: str = 'linear',
compress_kv_factor: int = 1,
interpolation_scale_1d: float = None,
):
) -> None:
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
@@ -249,15 +249,15 @@ def __init__(

self.gradient_checkpointing = False

def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, value: bool=False) -> None:
self.gradient_checkpointing = value

def make_position(self, b, t, use_image_num, h, w, device):
def make_position(self, b: int, t: int, use_image_num: int, h: int, w: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
pos_hw = self.position_getter_2d(b*(t+use_image_num), h, w, device) # fake_b = b*(t+use_image_num)
pos_t = self.position_getter_1d(b*h*w, t, device) # fake_b = b*h*w
return pos_hw, pos_t

def make_attn_mask(self, attention_mask, frame, dtype):
def make_attn_mask(self, attention_mask: torch.Tensor, frame: int, dtype: torch.dtype) -> torch.Tensor:
attention_mask = rearrange(attention_mask, 'b t h w -> (b t) 1 (h w)')
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
@@ -267,7 +267,7 @@ def make_attn_mask(attention_mask, frame, dtype):
attention_mask = attention_mask.to(self.dtype)
return attention_mask

def vae_to_diff_mask(self, attention_mask, use_image_num):
def vae_to_diff_mask(self, attention_mask: torch.Tensor, use_image_num: int) -> torch.Tensor:
dtype = attention_mask.dtype
# b, t+use_image_num, h, w, assume t as channel
# this version does not use 3d patch embedding
@@ -288,7 +288,7 @@ def forward(
use_image_num: int = 0,
enable_temporal_attentions: bool = True,
return_dict: bool = True,
):
) -> Union[Tuple[torch.Tensor,], Transformer3DModelOutput]:
"""
The [`Transformer2DModel`] forward method.

@@ -571,14 +571,14 @@ def forward(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous()

if not return_dict:
return (output,)

return Transformer3DModelOutput(sample=output)

@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs):
def from_pretrained_2d(cls, pretrained_model_path: str, subfolder=None, **kwargs):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

@@ -592,10 +592,10 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs):
return model

# depth = num_layers * 2
def LatteT2V_XL_122(**kwargs):
def LatteT2V_XL_122(**kwargs) -> LatteT2V:
return LatteT2V(num_layers=28, attention_head_dim=72, num_attention_heads=16, patch_size_t=1, patch_size=2,
norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs)
def LatteT2V_D64_XL_122(**kwargs):
def LatteT2V_D64_XL_122(**kwargs) -> LatteT2V:
return LatteT2V(num_layers=28, attention_head_dim=64, num_attention_heads=18, patch_size_t=1, patch_size=2,
norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs)
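
The Union[Tuple[torch.Tensor,], Transformer3DModelOutput] return annotation added to forward mirrors the existing return_dict branch: a bare one-element tuple when return_dict=False, otherwise the output dataclass. A minimal, standalone sketch of the same pattern (the names below are illustrative and not taken from this PR):

from dataclasses import dataclass
from typing import Tuple, Union

import torch


@dataclass
class DemoModelOutput:
    # stand-in for Transformer3DModelOutput
    sample: torch.Tensor


def demo_forward(
    hidden_states: torch.Tensor,
    return_dict: bool = True,
) -> Union[Tuple[torch.Tensor,], DemoModelOutput]:
    # trivial computation standing in for the transformer blocks
    output = hidden_states * 2.0
    if not return_dict:
        return (output,)
    return DemoModelOutput(sample=output)

With the union spelled out, a static checker can verify that call sites handle both the tuple and the dataclass case.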

41 changes: 20 additions & 21 deletions opensora/sample/pipeline_videogen.py
@@ -16,7 +16,7 @@
import inspect
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union, Any, Dict

import torch
import einops
@@ -104,7 +104,7 @@ def __init__(
vae: AutoencoderKL,
transformer: Transformer2DModel,
scheduler: DPMSolverMultistepScheduler,
):
) -> None:
super().__init__()

self.register_modules(
@@ -114,7 +114,7 @@ def __init__(
# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
def mask_text_embeddings(self, emb: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, int]:
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
@@ -134,7 +134,7 @@ def encode_prompt(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
mask_feature: bool = True,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Encodes the prompt into text encoder hidden states.

@@ -280,19 +280,17 @@ def encode_prompt(
# masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0)

# print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...])

return masked_prompt_embeds, masked_negative_prompt_embeds
# return masked_prompt_embeds_, masked_negative_prompt_embeds_

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
def prepare_extra_step_kwargs(self, generator: Optional[Union[torch.Generator, List[torch.Generator]]], eta: float) -> Dict[str, Any]:
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]

accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
@@ -302,18 +300,19 @@ def prepare_extra_step_kwargs(self, generator, eta):
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator

return extra_step_kwargs

def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
):
prompt: Union[str, List[str]],
height: Optional[int],
width: Optional[int],
negative_prompt: Optional[torch.FloatTensor],
callback_steps: Optional[Callable[[int, int, torch.FloatTensor], None]],
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
) -> None:
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

@@ -358,7 +357,7 @@ def check_inputs(
)

# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
def _text_preprocessing(self, text: Union[str, List[str]], clean_caption: bool=False) -> List[str]:
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
@@ -379,11 +378,11 @@ def process(text: str):
else:
text = text.lower().strip()
return text

return [process(t) for t in text]

# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
def _clean_caption(self, caption: str) -> str:
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
@@ -501,8 +500,8 @@ def _clean_caption(self, caption):
return caption.strip()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator,
latents=None):
def prepare_latents(self, batch_size: int, num_channels_latents: int, num_frames: int, height: Optional[int], width: Optional[int], dtype: torch.float16, device: torch.device, generator: Optional[Union[torch.Generator, List[torch.Generator]]],
latents: Optional[torch.FloatTensor]=None) -> torch.Tensor:
shape = (
batch_size,
num_channels_latents,
@@ -755,7 +754,7 @@ def __call__(

return VideoPipelineOutput(video=video)

def decode_latents(self, latents):
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
video = self.vae.decode(latents) # b t c h w
# b t c h w -> b t h w c
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
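
The Dict[str, Any] return type added to prepare_extra_step_kwargs matches the signature-introspection logic shown above: eta and generator are only forwarded when the scheduler's step method actually accepts them. A standalone sketch of that pattern under the new annotations (the scheduler class below is a stand-in, not the pipeline's DPMSolverMultistepScheduler):

import inspect
from typing import Any, Dict, List, Optional, Union

import torch


class DemoScheduler:
    # stand-in scheduler whose step() accepts a generator but no eta
    def step(self, model_output, timestep, sample, generator=None):
        return sample


def prepare_extra_step_kwargs(
    scheduler: DemoScheduler,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    eta: float,
) -> Dict[str, Any]:
    step_params = set(inspect.signature(scheduler.step).parameters.keys())
    extra_step_kwargs: Dict[str, Any] = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs


print(prepare_extra_step_kwargs(DemoScheduler(), torch.Generator(), eta=0.0))
# prints {'generator': <torch._C.Generator object at 0x...>}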
2 changes: 1 addition & 1 deletion opensora/sample/sample_t2v.py
@@ -28,7 +28,7 @@
import imageio


def main(args):
def main(args: argparse.Namespace) -> None:
# torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
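
With the entry points annotated, the three touched files can be checked statically. A minimal sketch of one way to run such a check from Python, assuming mypy is installed (the flag choice is illustrative and not part of this PR):

from mypy import api

# Type-check the files touched by this PR; --ignore-missing-imports silences
# errors for third-party packages that ship without type stubs.
stdout, stderr, exit_status = api.run([
    "opensora/models/diffusion/latte/modeling_latte.py",
    "opensora/sample/pipeline_videogen.py",
    "opensora/sample/sample_t2v.py",
    "--ignore-missing-imports",
])
print(stdout)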