Commit

fixed COIN eval, need to tune hyperparameters in new env
joya.chen committed Aug 15, 2024
1 parent 50d13f9 commit 755e265
Showing 13 changed files with 111 additions and 63 deletions.
71 changes: 40 additions & 31 deletions data/coin/benchmarks.py
@@ -1,4 +1,4 @@
import Levenshtein as lev
import Levenshtein
import numpy as np
from transformers import PreTrainedTokenizer, EvalPrediction

@@ -11,16 +11,18 @@ class COINBenchmark(COIN, StreamMixIn):

@staticmethod
def fuzzy_match(text, choices):
scores = [-lev.distance(text, choice) for choice in choices]
return scores.index(max(scores))
return min([(Levenshtein.distance(text, choice), choice) for choice in choices])[1]

def compute_metrics(self, eval_predictions: EvalPrediction, tokenizer: PreTrainedTokenizer, **kwargs):
batch_pred_tensor, sample_idxs = eval_predictions.predictions, eval_predictions.label_ids
batch_pred_tensor = batch_pred_tensor.clip(min=0)
batch_pred_tensor[batch_pred_tensor < 0] = tokenizer.bos_token_id # do not use clip(min=0): token id 0 decodes to '!' in the Llama-3 tokenizer and would distort matching
predictions = tokenizer.batch_decode(batch_pred_tensor, skip_special_tokens=True, clean_up_tokenization_spaces=True)
predictions = np.array([self.fuzzy_match(text, self.mapping_categories) for text in predictions])
accuracy = (predictions == np.array(self.answers)).mean()
return dict(accuracy=accuracy)
correct = 0
for prediction, label in zip(predictions, self.labels[sample_idxs]): # index labels by sample_idxs so they line up with the prediction order
prediction = prediction.lower().rstrip('.')
if prediction == label or self.fuzzy_match(prediction, self.categories) == label:
correct += 1
return dict(accuracy=correct / len(predictions) * 100) # report accuracy as a percentage

def __getitem__(self, index):
anno = self.annos[index]
@@ -36,14 +38,14 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
super().__init__(split=split, frame_fps=frame_fps, is_training=is_training, **kwargs)
self.is_training = is_training
self.frame_fps = frame_fps
self.annos = []
self.answers, self.mapping_categories = [], self.steps_categories
self.annos, self.labels = [], []
for anno in self._annos:
video_uid = anno['video_uid']
duration = self.metadata[video_uid]['duration']
steps = anno['steps']
for i in range(len(steps)):
response = steps[i]['text'].capitalize() + '.'
self.labels.append(steps[i]['text'].lower())
start_time = ceil_time_by_fps(steps[i]['start'], frame_fps, min_time=0, max_time=duration)
end_time = ceil_time_by_fps(steps[i]['end'], frame_fps, min_time=0, max_time=duration)
start_frame = int(start_time * frame_fps)
@@ -57,7 +59,8 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
'conversation': conversation,
'load_ranges': {self.metadata[video_uid]['path']: range(start_frame, end_frame)}
})
self.answers.append(self.mapping_categories.index(response))
self.labels = np.array(self.labels) # for fast indexing
self.categories = self.step_categories

def build_coin_step_train(**kwargs):
return COINStep(split='train', **kwargs)
@@ -74,14 +77,14 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
super().__init__(split=split, frame_fps=frame_fps, is_training=is_training, **kwargs)
self.is_training = is_training
self.frame_fps = frame_fps
self.annos = []
self.answers, self.mapping_categories = [], self.steps_categories
self.annos, self.labels = [], []
for anno in self._annos:
video_uid = anno['video_uid']
duration = self.metadata[video_uid]['duration']
steps = anno['steps']
for i in range(len(steps) - 1):
response = steps[i+1]['text'].capitalize() + '.'
self.labels.append(steps[i+1]['text'].lower())
start_time = ceil_time_by_fps(steps[i]['start'], frame_fps, min_time=0, max_time=duration)
end_time = ceil_time_by_fps(steps[i]['end'], frame_fps, min_time=0, max_time=duration)
start_frame = int(start_time * frame_fps)
@@ -95,7 +98,8 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
'conversation': conversation,
'load_ranges': {self.metadata[video_uid]['path']: range(start_frame, end_frame)}
})
self.answers.append(self.mapping_categories.index(response))
self.labels = np.array(self.labels) # for fast indexing
self.categories = self.step_categories

def build_coin_next_train(**kwargs):
return COINNext(split='train', **kwargs)
@@ -112,12 +116,12 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
super().__init__(split=split, frame_fps=frame_fps, is_training=is_training, **kwargs)
self.is_training = is_training
self.frame_fps = frame_fps
self.annos = []
self.answers, self.mapping_categories = [], self.tasks_categories
self.annos, self.labels = [], []
for anno in self._annos:
video_uid = anno['video_uid']
duration = self.metadata[video_uid]['duration']
response = anno['task'].capitalize() + '.'
self.labels.append(anno['task'].lower())
start_time = ceil_time_by_fps(anno['start'], frame_fps, min_time=0, max_time=duration)
end_time = ceil_time_by_fps(anno['end'], frame_fps, min_time=0, max_time=duration)
start_frame = int(start_time * frame_fps)
@@ -131,7 +135,8 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
'conversation': conversation,
'load_ranges': {self.metadata[video_uid]['path']: range(start_frame, end_frame)}
})
self.answers.append(self.mapping_categories.index(response))
self.labels = np.array(self.labels) # for fast indexing
self.categories = self.task_categories

def build_coin_task_train(**kwargs):
return COINTask(split='train', **kwargs)
@@ -149,8 +154,7 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
super().__init__(split=split, frame_fps=frame_fps, is_training=is_training, **kwargs)
self.is_training = is_training
self.frame_fps = frame_fps
self.annos = []
self.answers, self.mapping_categories = [], self.steps_categories
self.annos, self.labels = [], []
for anno in self._annos:
video_uid = anno['video_uid']
duration = self.metadata[video_uid]['duration']
@@ -168,30 +172,34 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
{"role": "stream", 'num_frames': end_frame - start_frame, 'learn': True}
]
response = next_steps[0]['text'].capitalize() + '.'
self.labels.append(np.array([next_steps[0]['text'].lower()]))
else:
conversation = [
COINProcedure.user_message(num_next_steps),
{"role": "stream", 'num_frames': end_frame - start_frame, 'learn': True}
]
response = '\n'.join(f"{i+1}. {s['text'].capitalize()}." for i, s in enumerate(next_steps))
self.labels.append(np.array([s['text'].lower() for s in next_steps]))
conversation.append({"role": "assistant", "content": response, 'learn': True})
self.annos.append({
'conversation': conversation,
'load_ranges': {self.metadata[video_uid]['path']: range(start_frame, end_frame)}
})
self.answers.append([self.mapping_categories.index(step['text'].capitalize() + '.') for step in next_steps])
self.categories = self.step_categories

def compute_metrics(self, eval_predictions: EvalPrediction, tokenizer: PreTrainedTokenizer, **kwargs):
batch_pred_tensor, sample_idxs = eval_predictions.predictions, eval_predictions.label_ids
batch_pred_tensor = batch_pred_tensor.clip(min=0)
batch_pred_text = tokenizer.batch_decode(batch_pred_tensor, skip_special_tokens=True, clean_up_tokenization_spaces=True)
predictions = []
for pred_text in batch_pred_text:
pred_steps = pred_text.split('\n')
predictions.append([self.fuzzy_match(step, self.mapping_categories) for step in pred_steps])
total_num_steps = len(sum(self.answers, []))
correct_num_steps = sum([sum(1 for p, a in zip(prediction, answer) if p == a) for prediction, answer in zip(predictions, self.answers)])
return {'accuracy': correct_num_steps / total_num_steps}
batch_pred_tensor[batch_pred_tensor < 0] = tokenizer.bos_token_id
predictions = tokenizer.batch_decode(batch_pred_tensor, skip_special_tokens=True, clean_up_tokenization_spaces=True)
correct, total = 0, 0
labels = [self.labels[i] for i in sample_idxs]
for prediction_steps, label_steps in zip(predictions, labels):
for prediction_step, label_step in zip(prediction_steps.split('\n'), label_steps):
prediction_step = prediction_step.split('. ')[-1]
if prediction_step == label_step or self.fuzzy_match(prediction_step, self.categories) == label_step:
correct += 1
total += 1
return {'accuracy': correct / total * 100}

def build_coin_procedure_train(**kwargs):
return COINProcedure(split='train', **kwargs)
@@ -213,8 +221,7 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
super().__init__(split=split, frame_fps=frame_fps, is_training=is_training, **kwargs)
self.is_training = is_training
self.frame_fps = frame_fps
self.annos = []
self.answers, self.mapping_categories = [], self.steps_categories
self.annos, self.labels = [], []
for anno in self._annos:
video_uid = anno['video_uid']
duration = self.metadata[video_uid]['duration']
@@ -232,18 +239,20 @@ def __init__(self, *, split: str, frame_fps: int, is_training: bool, **kwargs):
{"role": "stream", 'num_frames': end_frame - start_frame, 'learn': True}
]
response = next_steps[0]['text'].capitalize() + '.'
self.labels.append([next_steps[0]['text'].lower()])
else:
conversation = [
COINTaskProcedure.get_query_multi(anno['task'], num_next_steps),
{"role": "stream", 'num_frames': end_frame - start_frame, 'learn': True}
]
response = '\n'.join(f"{i+1}. {s['text'].capitalize()}." for i, s in enumerate(next_steps))
self.labels.append([s['text'].lower() for s in next_steps])
conversation.append({"role": "assistant", "content": response, 'learn': True})
self.annos.append({
'conversation': conversation,
'load_ranges': {self.metadata[video_uid]['path']: range(start_frame, end_frame)}
})
self.answers.append([self.mapping_categories.index(step['text'].capitalize() + '.') for step in next_steps])
self.categories = self.step_categories

def compute_metrics(self, *args, **kwargs):
return COINProcedure.compute_metrics(self, *args, **kwargs)
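
For readers skimming the diff: evaluation now lowercases each prediction, strips the trailing period, and falls back to edit-distance matching over the category vocabulary. A minimal standalone sketch of that logic, with invented category strings:

```
import Levenshtein

def fuzzy_match(text, choices):
    # pick the category with the smallest edit distance to the decoded text
    return min([(Levenshtein.distance(text, choice), choice) for choice in choices])[1]

categories = ['grind the beans', 'pour the water', 'whisk the eggs']
prediction = 'Whisk eggs.'.lower().rstrip('.')  # normalized as in compute_metrics
label = 'whisk the eggs'
print(prediction == label or fuzzy_match(prediction, categories) == label)  # True
```

On the padding change: masked positions (negative ids) are filled with `bos_token_id` rather than clipped to 0, since token id 0 decodes to '!' in the Llama-3 tokenizer and stray '!' characters would distort the string matching above.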
9 changes: 3 additions & 6 deletions data/coin/coin.py
@@ -22,13 +22,10 @@ def __init__(self, split: str, vision_pretrained: str, embed_mark: str, frame_fp
text=COIN._clean_step(step['label']),
) for step in anno['annotation']],
} for video_uid, anno in annos.items() if (split in anno['subset'].lower()) and (video_uid in self.metadata)]
self.tasks_categories = list(set([v['task'].capitalize() + '.' for v in self._annos]))
self.steps_categories = list(set([step['text'].capitalize() + '.' for steps in self._annos for step in steps['steps']]))
self.task_categories = list(set([v['task'].lower() for v in self._annos]))
self.step_categories = list(set([step['text'].lower() for steps in self._annos for step in steps['steps']]))
self.annos: list[dict]

def __len__(self):
return len(self.annos)

def get_metadata(self, ):
metadata_path = f'{self.embed_dir}_metadata.json'
if os.path.exists(metadata_path):
@@ -70,4 +67,4 @@ def _clean_task(text):
return result.strip()

def __len__(self):
return len(self.annos)
return len(self.annos)
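
The vocabularies in coin.py are now built lowercase so categories, labels, and decoded predictions share a single normalization; a small illustration with invented annotations:

```
annos = [
    {'task': 'Make Coffee',
     'steps': [{'text': 'Grind The Beans'}, {'text': 'Pour The Water'}]},
]
task_categories = list(set(v['task'].lower() for v in annos))
step_categories = list(set(s['text'].lower() for v in annos for s in v['steps']))

# a prediction normalized the same way matches exactly, with no fuzzy fallback needed
assert 'Grind the beans.'.lower().rstrip('.') in step_categories
```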
44 changes: 44 additions & 0 deletions data/coin/download_videos.py
@@ -0,0 +1,44 @@
import json, os, argparse, subprocess, random, torchvision
import concurrent.futures
try:
torchvision.set_video_backend('video_reader')
except Exception:
import av  # fall back to PyAV; make sure av is installed

def download_video(video_id, video_url, output_dir, ffmpeg_location=None):
output_path = os.path.join(output_dir, f'{video_id}.mp4')
if os.path.exists(output_path):
try:
ffmpeg_cmd = ["ffmpeg", "-v", "error", "-i", output_path, "-f", "null", "-"]
if ffmpeg_location:
ffmpeg_cmd[0] = os.path.join(ffmpeg_location, "ffmpeg")
subprocess.run(ffmpeg_cmd, check=True)
print(f'{output_path} has been downloaded and verified...')
return
except Exception:
print(f'{output_path} may be broken. Downloading it again...')
os.remove(output_path)
cmd = ["yt-dlp", "--username", "oauth2", "--password", "", "-f", "mp4", "-o", output_path, video_url]
if ffmpeg_location:
cmd.extend(["--ffmpeg-location", ffmpeg_location])
subprocess.run(cmd, check=True)

def main(output_dir, json_path, num_workers, ffmpeg_location):
annotations = json.load(open(json_path, 'r'))['database']
annotations = list(annotations.items())
random.shuffle(annotations)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(download_video, video_id, annotation['video_url'], output_dir, ffmpeg_location) for video_id, annotation in annotations]
concurrent.futures.wait(futures)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Download videos in parallel using yt-dlp')
parser.add_argument('--output_dir', type=str, default='datasets/coin/videos', help='Directory to save downloaded videos')
parser.add_argument('--json_path', type=str, default='datasets/coin/coin.json', help='Path to the JSON file containing video data')
parser.add_argument('--ffmpeg', type=str, default=None)
parser.add_argument('--num_workers', type=int, default=16, help='Number of parallel downloads')

args = parser.parse_args()

os.makedirs(args.output_dir, exist_ok=True)
main(args.output_dir, args.json_path, args.num_workers, args.ffmpeg)
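
A usage sketch for the new downloader, run from the repository root (the paths are the script's defaults; pass `--ffmpeg` only if ffmpeg is not on your PATH):

```
python -m data.coin.download_videos \
    --json_path datasets/coin/coin.json \
    --output_dir datasets/coin/videos \
    --num_workers 16
```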
8 changes: 5 additions & 3 deletions data/preprocess/README.md
@@ -3,19 +3,21 @@
#### Sample video frames to 2 FPS and max resolution 384 (with zero padding)

```
python -m data.preprocess.ffmpeg --num_gpus 8 --frame_fps 2 --frame_resolution 384 --video_dir datasets/ego4d/v2/full_scale
python -m data.preprocess.ffmpeg --num_gpus 8 --frame_fps 2 --frame_resolution 384 --num_tasks 16 --video_dir datasets/ego4d/v2/full_scale
```

- Please run the script from the ```videollm-online/``` root folder.

- The results will be saved in a new folder with a '{fps}fps_{resolution}' suffix. For example, ```datasets/ego4d/v2/full_scale -> datasets/ego4d/v2/full_scale_2fps_384```.

- If you are on a cluster, you can set ```--num_nodes ... --slurm_partition ...``` to use them. The more nodes and GPUs, the faster preprocessing.
- Increase ```--num_tasks``` according to the number of CPU cores; about 1/10 of the core count is recommended.

- If you are on a cluster, you can set ```--num_nodes ... --slurm_partition ...``` to distribute the work. The more nodes, the faster the preprocessing.

#### Encode sampled 2fps_384 video frames

```
python -m data.preprocess.encode --num_gpus 8 --video_dir datasets/ego4d/v2/full_scale_2fps_384 --vision_pretrained google/siglip-large-patch16-384
python -m data.preprocess.encode --num_gpus 8 --vision_pretrained google/siglip-large-patch16-384 --video_dir datasets/ego4d/v2/full_scale_2fps_384
```

- Please run the script from the ```videollm-online/``` root folder.
1 change: 1 addition & 0 deletions data/preprocess/encode.py
@@ -36,3 +36,4 @@ class LiveOnePlusEncodingArguments(LiveOnePlusTrainingArguments):
timeout_min=600,
)
job = executor.submit(task)
job.results()
6 changes: 3 additions & 3 deletions data/preprocess/ffmpeg.py
@@ -8,7 +8,7 @@
@dataclass
class LiveOnePlusEncodingArguments(LiveOnePlusTrainingArguments):
num_nodes: int = 1
num_gpus: int = 8
num_tasks: int = 16
video_dir: str = 'datasets/ego4d/v2/full_scale'
slurm_partition: str = None

@@ -17,13 +17,13 @@ class LiveOnePlusEncodingArguments(LiveOnePlusTrainingArguments):
executor = submitit.AutoExecutor(folder=f"outputs/preprocess/", cluster='local' if args.num_nodes == 1 else 'slurm')
task = partial(distributed_ffmpeg, src_root=args.video_dir, resolution=args.frame_resolution, fps=args.frame_fps)
executor.update_parameters(
tasks_per_node=args.num_gpus,
tasks_per_node=args.num_tasks,
nodes=args.num_nodes,
gpus_per_node=args.num_gpus,
slurm_partition=args.slurm_partition,
cpus_per_task=10,
mem_gb=240,
slurm_time='24:00:00',
timeout_min=600,
)
job = executor.submit(task)
job.results()
2 changes: 1 addition & 1 deletion data/stream.py
@@ -88,7 +88,7 @@ def __getitem__(self, *, conversation: list[dict], load_ranges: dict[str, range]
frames = load_ranges
elif load_ranges is not None:
conversation, load_ranges = self.max_frames_clip(conversation, load_ranges, self.max_num_frames)
frames = torch.cat([torch.load(path)[ranger] for path, ranger in load_ranges.items()])
frames = torch.cat([torch.load(path, weights_only=True)[ranger] for path, ranger in load_ranges.items()])
else:
frames = torch.tensor([])
# 2. prepare texts
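
The stream.py change opts into `weights_only=True`, which restricts `torch.load` to tensor and primitive payloads rather than arbitrary pickled objects; a minimal sketch, with a hypothetical file name:

```
import torch

frames = torch.load('video_embeddings.pt', weights_only=True)  # tensors only, no pickled code
clip = frames[range(16, 48)]  # index by a range of frame numbers, as stream.py does
```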
6 changes: 3 additions & 3 deletions data/utils.py
@@ -65,7 +65,7 @@ def ffmpeg_once(src_path: str, dst_path: str, *, fps: int = None, resolution: in
command += [dst_path]
subprocess.run(command, check=True)

def distributed_ffmpeg(*, src_root: str, fps: int = None, resolution: int = None, pad: str = '#000000', mode='bicubic', **kwargs):
def distributed_ffmpeg(*, src_root: str, fps: int = None, resolution: int = None, pad: str = '#000000', mode='bicubic'):
import submitit
env = submitit.JobEnvironment()
src_root = src_root.rstrip('/')
@@ -81,14 +81,14 @@ def distributed_ffmpeg(*, src_root: str, fps: int = None, resolution: int = None
if i % env.num_tasks != env.global_rank:
continue
dst_path = src_path.replace(src_root, dst_root)
ffmpeg_once(src_path, dst_path, fps=fps, resolution=resolution, pad=pad)
ffmpeg_once(src_path, dst_path, fps=fps, resolution=resolution, pad=pad, mode=mode)

def distributed_encode(*, src_root: str, vision_pretrained: str, vision_encode: callable, batch_size: int, embed_mark: str, save_bf16: bool = False, **kwargs):
env = submitit.JobEnvironment()
src_root = src_root.rstrip('/')
model = AutoModel.from_pretrained(vision_pretrained, device_map=f'cuda:{env.local_rank}').vision_model
model.eval()
dst_root = f"{src_root}_{embed_mark}_{vision_pretrained.replace('/', '--')}"
dst_root = f"{src_root}_{embed_mark.split('_')[-1]}_{vision_pretrained.replace('/', '--')}"
os.makedirs(dst_root, exist_ok=True)
for i, file in tqdm.tqdm(enumerate(os.listdir(src_root)), desc=f'{src_root} -> {dst_root}'):
if i % env.num_tasks != env.global_rank:
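
Both distributed_ffmpeg and distributed_encode shard files across submitit tasks with the same modulo pattern; a minimal sketch of the idiom, with process_file as a hypothetical stand-in for ffmpeg_once or an encoder forward pass:

```
import submitit

def process_file(path):
    print(f'processing {path}')  # hypothetical per-file worker

def distributed_worker(files):
    env = submitit.JobEnvironment()  # populated inside a running submitit task
    for i, path in enumerate(files):
        # each task handles every num_tasks-th file, offset by its global rank,
        # so the file list is covered exactly once across all tasks
        if i % env.num_tasks != env.global_rank:
            continue
        process_file(path)
```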