-
Notifications
You must be signed in to change notification settings - Fork 3
/
LoadVideoBatchFrame.py
105 lines (83 loc) · 4.05 KB
/
LoadVideoBatchFrame.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import glob
import random
import cv2
import torch
import numpy as np
# File extensions (matched case-insensitively against the lowercased path)
# that the batch loader treats as videos.
ALLOWED_VIDEO_EXT = (
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
)
class LoadVideoBatchFrame:
    """Node that loads a single frame from one video in a directory of videos.

    Videos are discovered with ``glob`` (``path`` joined with ``pattern``),
    filtered by ``ALLOWED_VIDEO_EXT`` and sorted, then one video is chosen
    according to ``mode`` and a single frame is decoded from it.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mode": (["single_video", "incremental_video", "random"],),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "index": ("INT", {"default": 0, "min": 0, "max": 150000, "step": 1}),
                "frame": ("INT", {"default": 0, "min": 0, "max": 999999, "step": 1}),
                "label": ("STRING", {"default": 'Video Batch 001', "multiline": False}),
                "path": ("STRING", {"default": '', "multiline": False}),
                "pattern": ("STRING", {"default": '*', "multiline": False}),
            },
        }

    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("frame", "filename_text")
    FUNCTION = "load_batch_videos"
    CATEGORY = "image/video"

    def load_batch_videos(self, path, pattern='*', index=0, frame=0, mode="single_video", seed=0, label='Video Batch 001'):
        """Load one frame from a video selected from ``path``.

        Args:
            path: Directory to search for videos.
            pattern: Glob pattern relative to ``path`` (``**`` is honoured).
            index: Video index for ``single_video``; for ``incremental_video``
                it wraps around the number of videos found.
            frame: Frame number to decode; clamped to the last frame when the
                container reports a positive frame count.
            mode: ``"single_video"``, ``"incremental_video"`` or ``"random"``
                (any other value behaves like ``"random"``).
            seed: Seed for the ``random`` mode's video choice.
            label: Accepted for interface compatibility; not used here.

        Returns:
            ``(frame_tensor, filename)``, where ``frame_tensor`` is a float32
            RGB tensor shaped ``[1, H, W, 3]`` in ``[0, 1]`` and ``filename``
            is the video's base name.

        Raises:
            ValueError: If ``path`` does not exist or no frame could be read.
        """
        if not os.path.exists(path):
            raise ValueError(f"Path does not exist: {path}")
        fl = self.BatchVideoLoader(path, pattern)
        if mode == 'single_video':
            frame_tensor, filename = fl.get_video_frame_by_id(index, frame)
            if frame_tensor is None:
                raise ValueError(f"No valid video frame found for index {index} and frame {frame}")
        elif mode == 'incremental_video':
            frame_tensor, filename = fl.get_next_video_frame(index, frame)
            if frame_tensor is None:
                raise ValueError("No valid video frame found")
        else:
            # Use a private RNG so the process-wide `random` state is not
            # reseeded as a side effect; Random(seed).random() yields the same
            # value that random.seed(seed); random.random() would.
            rng = random.Random(seed)
            newindex = int(rng.random() * len(fl.video_paths))
            frame_tensor, filename = fl.get_video_frame_by_id(newindex, frame)
            if frame_tensor is None:
                raise ValueError("No valid video frame found")
        return (frame_tensor, filename)

    class BatchVideoLoader:
        """Sorted list of video files matching a glob, with frame access."""

        def __init__(self, directory_path, pattern):
            self.video_paths = []
            self.load_videos(directory_path, pattern)
            self.video_paths.sort()
            # Kept for compatibility; not read by any method in this class.
            self.index = 0

        def load_videos(self, directory_path, pattern):
            """Append absolute paths of matching video files to ``video_paths``."""
            # escape() keeps glob metacharacters in the directory name literal;
            # only `pattern` is interpreted as a glob.
            search = os.path.join(glob.escape(directory_path), pattern)
            for file_name in glob.glob(search, recursive=True):
                # Skip directories that merely end in a video extension.
                if file_name.lower().endswith(ALLOWED_VIDEO_EXT) and os.path.isfile(file_name):
                    self.video_paths.append(os.path.abspath(file_name))

        def get_video_frame_by_id(self, video_id, frame_number):
            """Decode one frame of video ``video_id``.

            Returns ``(tensor, basename)`` on success, or ``(None, None)`` if
            the id is out of range, the file cannot be opened, or the read
            fails.
            """
            if video_id < 0 or video_id >= len(self.video_paths):
                return None, None
            video_path = self.video_paths[video_id]
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    return None, None
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                # Some containers report 0 frames; only clamp against a
                # known positive count, and never seek before the first frame.
                if total_frames > 0 and frame_number >= total_frames:
                    frame_number = total_frames - 1
                frame_number = max(frame_number, 0)
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
                ret, frame = cap.read()
            finally:
                # Release the decoder even if seeking or reading raises.
                cap.release()
            if not ret or frame is None:
                return None, None
            # OpenCV decodes to BGR; convert to RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Normalize 8-bit values to float32 in [0, 1].
            frame_float = frame_rgb.astype(np.float32) / 255.0
            # Add a leading batch axis: [1, height, width, channels].
            frame_tensor = torch.from_numpy(frame_float)[None,]
            return (frame_tensor, os.path.basename(video_path))

        def get_next_video_frame(self, index, frame_number):
            """Decode a frame from video ``index``, wrapping past the end.

            The index wraps modulo the number of videos, so an ever-increasing
            index cycles through the batch instead of sticking on video 0.
            """
            if not self.video_paths:
                return None, None
            index = index % len(self.video_paths)
            return self.get_video_frame_by_id(index, frame_number)