From a2345d54823f966589edaf69a4fbfcdb8b96e5c1 Mon Sep 17 00:00:00 2001
From: Yujie He
Date: Fri, 12 Apr 2024 14:07:57 +0800
Subject: [PATCH] Refactor sample_t2v to make deployment simpler; add some type hints

Signed-off-by: Yujie He
---
 opensora/sample/pipeline_videogen.py |  19 ++-
 opensora/sample/sample_t2v.py        | 231 ++++++++++++++++-----------
 2 files changed, 154 insertions(+), 96 deletions(-)

diff --git a/opensora/sample/pipeline_videogen.py b/opensora/sample/pipeline_videogen.py
index 303414b8c..220acb4e6 100644
--- a/opensora/sample/pipeline_videogen.py
+++ b/opensora/sample/pipeline_videogen.py
@@ -19,11 +19,8 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
-import einops
-from einops import rearrange
 from transformers import T5EncoderModel, T5Tokenizer
 
-from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL, Transformer2DModel
 from diffusers.schedulers import DPMSolverMultistepScheduler
 from diffusers.utils import (
@@ -501,8 +498,18 @@ def _clean_caption(self, caption):
         return caption.strip()
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
-    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator,
-                        latents=None):
+    def prepare_latents(
+        self,
+        batch_size: int,
+        num_channels_latents: int,
+        video_length: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: Union[str, torch.device],
+        generator: Optional[torch.Generator],
+        latents: Optional[torch.FloatTensor] = None
+    ):
         shape = (
             batch_size, num_channels_latents, video_length, self.vae.latent_size[0], self.vae.latent_size[1])
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -750,7 +757,7 @@ def __call__(
 
         return VideoPipelineOutput(video=video)
 
-    def decode_latents(self, latents):
+    def decode_latents(self, latents: torch.FloatTensor):
         video = self.vae.decode(latents)
         # video = self.vae.decode(latents / 0.18215)
         # video = rearrange(video, 'b c t h w -> b t c h w').contiguous()
diff --git a/opensora/sample/sample_t2v.py b/opensora/sample/sample_t2v.py
index 045e48097..bef679822 100644
--- a/opensora/sample/sample_t2v.py
+++ b/opensora/sample/sample_t2v.py
@@ -1,101 +1,170 @@
 import math
 import os
-import torch
 import argparse
-import torchvision
+import os, sys
+from typing import List, Union
+import imageio
+import torch
+from torchvision.utils import save_image
 from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
                                   EulerDiscreteScheduler, DPMSolverMultistepScheduler,
                                   HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
                                   DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
 from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
-from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
-from omegaconf import OmegaConf
-from torchvision.utils import save_image
-from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer
-
-import os, sys
+from transformers import T5EncoderModel, T5Tokenizer
 
-from opensora.models.ae import ae_stride_config, getae, getae_wrapper
-from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
+from opensora.models.ae import ae_stride_config, getae_wrapper
 from opensora.models.diffusion.latte.modeling_latte import LatteT2V
-from opensora.models.text_encoder import get_text_enc
 from opensora.utils.utils import save_video_grid
-
 sys.path.append(os.path.split(sys.path[0])[0])
 from pipeline_videogen import VideoGenPipeline
-import imageio
-
-def main(args):
-    # torch.manual_seed(args.seed)
-    torch.set_grad_enabled(False)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
+
+
+def get_models(args: argparse.Namespace, device: str):
+    vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae").to(device, dtype=torch.float16)
     if args.enable_tiling:
         vae.vae.enable_tiling()
         vae.vae.tile_overlap_factor = args.tile_overlap_factor
 
     # Load model:
-    transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, cache_dir="cache_dir", torch_dtype=torch.float16).to(device)
+    transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, torch_dtype=torch.float16).to(device)
     transformer_model.force_images = args.force_images
-    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
-    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", torch_dtype=torch.float16).to(device)
-
-    video_length, image_size = transformer_model.config.video_length, int(args.version.split('x')[1])
-    latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
-    vae.latent_size = latent_size
-    if args.force_images:
-        video_length = 1
-        ext = 'jpg'
-    else:
-        ext = 'mp4'
+    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, torch_dtype=torch.float16).to(device)
 
     # set eval mode
     transformer_model.eval()
     vae.eval()
     text_encoder.eval()
+
+    return transformer_model, vae, text_encoder, tokenizer
+
+
+def get_scheduler(sample_method: str):
+    schedulers = {
+        'DDIM': DDIMScheduler(),
+        'EulerDiscrete': EulerDiscreteScheduler(),
+        'DDPM': DDPMScheduler(),
+        'DPMSolverMultistep': DPMSolverMultistepScheduler(),
+        'DPMSolverSinglestep': DPMSolverSinglestepScheduler(),
+        'PNDM': PNDMScheduler(),
+        'HeunDiscrete': HeunDiscreteScheduler(),
+        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler(),
+        'DEISMultistep': DEISMultistepScheduler(),
+        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler()
+    }
+    return schedulers[sample_method]
+
+
+def get_text_prompt(text_prompt: Union[List[str], str]):
+    if not isinstance(text_prompt, list):
+        text_prompt = [text_prompt]
+
+    if len(text_prompt) == 1 and text_prompt[0].endswith('txt'):
+        text_prompt = open(text_prompt[0], 'r').readlines()
+        text_prompt = [i.strip() for i in text_prompt]
+    return text_prompt
+
+
+def save_video(videos: torch.FloatTensor, prompt: str, args: argparse.Namespace):
+    """
+    Save a single video (output of pipeline).
+    """
+    # Save results
+    try:
+        if args.force_images:
+            videos = videos[:, 0].permute(0, 3, 1, 2)  # b t h w c -> b c h w
+            save_image(
+                videos / 255.0,
+                os.path.join(
+                    args.save_img_path,
+                    prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'),
+                nrow=1, normalize=True, value_range=(0, 1))  # t c h w
+
+        else:
+            imageio.mimwrite(
+                os.path.join(
+                    args.save_img_path,
+                    prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'
+                ),
+                videos[0],
+                fps=args.fps,
+                quality=9)  # highest quality is 10, lowest is 0
+    except:
+        print('Error when saving {}'.format(prompt))
+
+    return videos
+
+
+def save_grid(video_grids: List[torch.FloatTensor], args: argparse.Namespace):
+    video_grids = torch.cat(video_grids, dim=0)
+
+    # Save results
+    # torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6)
+    if args.force_images:
+        save_image(
+            video_grids / 255.0,
+            os.path.join(
+                args.save_img_path,
+                f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'),
+            nrow=math.ceil(math.sqrt(len(video_grids))),
+            normalize=True, value_range=(0, 1))
+    else:
+        video_grids = save_video_grid(video_grids)
+        imageio.mimwrite(
+            os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'),
+            video_grids,
+            fps=args.fps,
+            quality=9)
 
-    if args.sample_method == 'DDIM':  #########
-        scheduler = DDIMScheduler()
-    elif args.sample_method == 'EulerDiscrete':
-        scheduler = EulerDiscreteScheduler()
-    elif args.sample_method == 'DDPM':  #############
-        scheduler = DDPMScheduler()
-    elif args.sample_method == 'DPMSolverMultistep':
-        scheduler = DPMSolverMultistepScheduler()
-    elif args.sample_method == 'DPMSolverSinglestep':
-        scheduler = DPMSolverSinglestepScheduler()
-    elif args.sample_method == 'PNDM':
-        scheduler = PNDMScheduler()
-    elif args.sample_method == 'HeunDiscrete':  ########
-        scheduler = HeunDiscreteScheduler()
-    elif args.sample_method == 'EulerAncestralDiscrete':
-        scheduler = EulerAncestralDiscreteScheduler()
-    elif args.sample_method == 'DEISMultistep':
-        scheduler = DEISMultistepScheduler()
-    elif args.sample_method == 'KDPM2AncestralDiscrete':  #########
-        scheduler = KDPM2AncestralDiscreteScheduler()
+    print('save path {}'.format(args.save_img_path))
+
+    # save_videos_grid(video, f"./{prompt}.gif")
+
+
+def main(args):
+    # torch.manual_seed(args.seed)
+    torch.set_grad_enabled(False)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     print('videogen_pipeline', device)
+
+
+    # Prepare models and pipeline
+    transformer_model, vae, text_encoder, tokenizer = get_models(args, device)
+    scheduler = get_scheduler(args.sample_method)
     videogen_pipeline = VideoGenPipeline(vae=vae,
                                          text_encoder=text_encoder,
                                          tokenizer=tokenizer,
                                          scheduler=scheduler,
-                                         transformer=transformer_model).to(device=device)
+                                         transformer=transformer_model,
+                                         )
+    # Some pipeline configs
+    # videogen_pipeline.enable_sequential_cpu_offload()
     # videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+    # Prepare
+    video_length, image_size = transformer_model.config.video_length, int(args.version.split('x')[1])
+    latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
+    vae.latent_size = latent_size
+    if args.force_images:
+        video_length = 1
+        args.ext = 'jpg'
+    else:
+        args.ext = 'mp4'
+
     if not os.path.exists(args.save_img_path):
         os.makedirs(args.save_img_path)
+
+    # Get text prompts
+    text_prompt = get_text_prompt(args.text_prompt)
+
+
+    # Video generation
     video_grids = []
-    if not isinstance(args.text_prompt, list):
-        args.text_prompt = [args.text_prompt]
-    if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):
-        text_prompt = open(args.text_prompt[0], 'r').readlines()
-        args.text_prompt = [i.strip() for i in text_prompt]
-    for prompt in args.text_prompt:
+    for prompt in text_prompt:
         print('Processing the ({}) prompt'.format(prompt))
         videos = videogen_pipeline(prompt,
                                    video_length=video_length,
@@ -107,37 +176,18 @@ def main(args):
                                    num_images_per_prompt=1,
                                    mask_feature=True,
                                    ).video
-        try:
-            if args.force_images:
-                videos = videos[:, 0].permute(0, 3, 1, 2)  # b t h w c -> b c h w
-                save_image(videos / 255.0, os.path.join(args.save_img_path,
-                                                        prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
-                           nrow=1, normalize=True, value_range=(0, 1))  # t c h w
-
-            else:
-                imageio.mimwrite(
-                    os.path.join(
-                        args.save_img_path,
-                        prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'
-                    ), videos[0],
-                    fps=args.fps, quality=9)  # highest quality is 10, lowest is 0
-        except:
-            print('Error when saving {}'.format(prompt))
-        video_grids.append(videos)
-    video_grids = torch.cat(video_grids, dim=0)
-
-
-    # torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6)
-    if args.force_images:
-        save_image(video_grids / 255.0, os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
-                   nrow=math.ceil(math.sqrt(len(video_grids))), normalize=True, value_range=(0, 1))
-    else:
-        video_grids = save_video_grid(video_grids)
-        imageio.mimwrite(os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'), video_grids, fps=args.fps, quality=9)
-
-    print('save path {}'.format(args.save_img_path))
-
-    # save_videos_grid(video, f"./{prompt}.gif")
+
+        videos = save_video(videos, prompt, args)
+
+        # Save result
+        if args.save_grid:
+            video_grids.append(videos)
+
+
+    # Save results
+    if args.save_grid:
+        save_grid(video_grids, args)
+
 
 
 if __name__ == "__main__":
@@ -156,6 +206,7 @@ def main(args):
     parser.add_argument('--force_images', action='store_true')
     parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
     parser.add_argument('--enable_tiling', action='store_true')
+    parser.add_argument('--save_grid', action='store_true', help='Save all prompts in a grid')
     args = parser.parse_args()
     main(args)
\ No newline at end of file
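
Usage note (not part of the diff): a minimal sketch of how the refactored helpers
compose, mirroring the new main(). The checkpoint path, version string, autoencoder
key, text-encoder name and prompt file below are illustrative assumptions, not values
fixed by this patch, and the import assumes the repository layout lets sample_t2v
resolve pipeline_videogen.

    import argparse
    import torch
    from opensora.sample.sample_t2v import get_models, get_scheduler, get_text_prompt

    args = argparse.Namespace(
        model_path='path/to/Open-Sora-Plan-checkpoint',  # illustrative path
        version='65x512x512',                            # illustrative version subfolder
        ae='CausalVAEModel_4x8x8',                       # illustrative autoencoder key
        text_encoder_name='DeepFloyd/t5-v1_1-xxl',       # illustrative text encoder
        force_images=False,
        enable_tiling=False,
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The helpers return ready-to-use components for VideoGenPipeline.
    transformer_model, vae, text_encoder, tokenizer = get_models(args, device)
    scheduler = get_scheduler('PNDM')            # any key from get_scheduler's dict
    prompts = get_text_prompt('prompts.txt')     # a .txt file of prompts, or a list of strings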