error when initializing the OmniTokenizer #20

Open
dongzhuoyao opened this issue Sep 11, 2024 · 2 comments

File "/export/scratch/ra63nev/lab/discretediffusion/OmniTokenizer/omnitokenizer.py", line 108, in init
spatial_depth=args.spatial_depth, temporal_depth=args.temporal_depth, causal_in_temporal_transformer=args.causal_in_temporal_transformer, causal_in_peg=args.causal_in_peg,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Namespace' object has no attribute 'causal_in_temporal_transformer'. Did you mean: 'casual_in_temporal_transformer'?

I tried two ckpts; neither works.

vqgan_ckpt = "./pretrained_ckpt/imagenet_k600.ckpt"

vqgan_ckpt = "./pretrained_ckpt/imagenet_ucf.ckpt"

vqgan_omni = OmniTokenizer_VQGAN.load_from_checkpoint(vqgan_ckpt, strict=False)

omni_tokenizer = vqgan_omni.to(device)
image = load_and_preprocess_image(img_path)
image = image.to(device)
indices = omni_tokenizer.encode(image)
print(
    f"image {img_path} is encoded into tokens {indices}, with shape {indices.shape}"
)
# de-tokenization
reconstructed_image = omni_tokenizer.decode(indices)
reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)
reconstructed_image = (
    (reconstructed_image * 255.0)
    .permute(0, 2, 3, 1)
    .to("cpu", dtype=torch.uint8)
    .numpy()[0]
)
Image.fromarray(reconstructed_image).save("reconstructed_image_omni.png")
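load_and_preprocess_image is not defined in the snippet above; a minimal sketch of what such a helper could look like, assuming a 256x256 resize and the [0, 1] range that the reconstruction code above expects (neither of which is documented by the repo):

from PIL import Image
from torchvision import transforms

def load_and_preprocess_image(img_path, resolution=256):
    # Assumed preprocessing: fixed-size resize, then scale to [0, 1] floats.
    transform = transforms.Compose([
        transforms.Resize((resolution, resolution)),
        transforms.ToTensor(),  # PIL HWC uint8 -> CHW float32 in [0, 1]
    ])
    image = Image.open(img_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add batch dim: 1 x C x H x W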

LisIva commented Sep 30, 2024

I met the same problem and solved it. Go to omnitokenizer.py and correct these lines:

self.encoder = OmniTokenizer_Encoder(...causal_in_temporal_transformer=args.casual_in_temporal_transformer, causal_in_peg=args.casual_in_peg, ...)

self.decoder = OmniTokenizer_Decoder(...causal_in_temporal_transformer=args.casual_in_temporal_transformer, causal_in_peg=args.casual_in_peg, ...)

As you can see, the authors made a typo in the word "causal" (it is spelled "casual" in the saved training args), which is why the attribute 'causal_...' does not exist on the checkpoint's Namespace.
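If you would rather not edit the library source, another option (untested, and assuming these are standard PyTorch Lightning checkpoints that keep the constructor args under the 'hyper_parameters' key) is to add correctly spelled aliases to the saved args and load the patched file instead:

import torch

ckpt_path = "./pretrained_ckpt/imagenet_k600.ckpt"  # adjust to your checkpoint
ckpt = torch.load(ckpt_path, map_location="cpu")

# Assumed layout: a Lightning checkpoint storing the constructor args under
# 'hyper_parameters'; inspect the file first, the exact nesting may differ.
hparams = ckpt.get("hyper_parameters", {})

def add_causal_aliases(obj):
    # Copy every misspelled 'casual_*' entry to the correctly spelled 'causal_*' name.
    items = obj if isinstance(obj, dict) else vars(obj)
    for key in list(items):
        if key.startswith("casual_"):
            items["causal_" + key[len("casual_"):]] = items[key]

if isinstance(hparams, dict) and "args" in hparams:
    add_causal_aliases(hparams["args"])
else:
    add_causal_aliases(hparams)

torch.save(ckpt, "./pretrained_ckpt/imagenet_k600_patched.ckpt")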


NilanEkanayake commented Oct 3, 2024

Here's the code I use to encode and decode videos, once the errors mentioned above are corrected:

from OmniTokenizer import OmniTokenizer_VQGAN
import torch
from torchvision.io import write_video
import numpy as np
from decord import VideoReader, cpu
from einops import rearrange

device = 'cuda:0'
dtype = torch.bfloat16
vqgan = OmniTokenizer_VQGAN.load_from_checkpoint('imagenet_k600.ckpt', strict=False)
vqgan.requires_grad_(False)
vqgan.eval()
vqgan = vqgan.to(device, dtype=dtype)

video_reader = VideoReader('input.mp4', ctx=cpu(0))
fps = video_reader.get_avg_fps()

video = video_reader.get_batch(list(range(len(video_reader)))).asnumpy()
video = torch.from_numpy(video).to(dtype)
video = rearrange(video[:-3], 't h w c -> 1 c t h w') # skip last couple frames to avoid /4 errors, will change based on input frame count

video = video / 255  # scale to [0.0, 1.0]

video = video.to(device=device, dtype=dtype)

video = video - 0.5  # shift to [-0.5, 0.5]
video = video.clamp(-0.5, 0.5)

with torch.no_grad():
    tokens = vqgan.encode(video, is_image=False)
    print(tokens.shape)
    recons = vqgan.decode(tokens, is_image=False)

video_dec = recons.clamp(-0.5, 0.5)

video_dec = rearrange(video_dec.squeeze(0), 'c t h w -> t h w c') # format for output
video_dec = (video_dec + 0.5).clamp(0, 1)
video_dec = video_dec.cpu().float().numpy()

video_dec = (video_dec * 255).astype(np.uint8)  # Convert to 0-255 and write
write_video("recon.mp4", video_dec, fps=fps, options={'crf': '0'})
