Commit

released coin dataset scripts; preprocessing scripts with instructions; fix some typos
chenjoya committed Jun 24, 2024
1 parent e7e707e commit d9a721f
Showing 17 changed files with 410 additions and 254 deletions.
3 changes: 1 addition & 2 deletions .gitignore
@@ -9,5 +9,4 @@ __pycache__/
*.json
*.wav
/demo/rendering/*.mp4
.DS_Store
/data/preprocess
.DS_Store
6 changes: 3 additions & 3 deletions README.md
@@ -37,7 +37,7 @@ By passing ```--resume_from_checkpoint chenjoya/videollm-online-8b-v1plus```, th
Ensure you have Miniconda and Python version >= 3.10 installed, then run:
```sh
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers accelerate deepspeed peft editdistance Levenshtein tensorboard gradio moviepy
pip install transformers accelerate deepspeed peft editdistance Levenshtein tensorboard gradio moviepy submitit
pip install flash-attn --no-build-isolation
```

@@ -61,9 +61,9 @@ mv ChatTTS demo/rendering/

- Download streaming dialogue data and Ego4D video features (google/siglip-large-patch16-384) from <a href="https://drive.google.com/drive/folders/1EfWu0lTpQH_p-HnwpBiZFwCE-OsUNagl?usp=sharing" target="_blank"><img alt="Data" src="https://img.shields.io/badge/📁 Data-8e44ad?color=8e44ad" /></a>

- Refer to the examples under [scripts/](scripts/)
- Distributed preprocessing of video frames: sample at 2 FPS and max resolution 384, then use ```google/siglip-large-patch16-384``` to extract the CLS token with 3x3 average-pooled spatial tokens (a sketch follows this list). Please refer to [preprocess/](preprocess/)

More detailed instructions will be available soon.
- Refer to the examples under [scripts/](scripts/)
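
For orientation, here is a minimal sketch of the "1 CLS + 3x3" feature layout described in the preprocessing bullet above, assuming SigLIP's pooled output stands in for the CLS token and the 24x24 patch grid is average-pooled to 3x3. This is an illustration, not the repository's exact code (see [preprocess/](preprocess/) for that):

```python
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

name = 'google/siglip-large-patch16-384'
vision = AutoModel.from_pretrained(name).vision_model.eval()
processor = AutoImageProcessor.from_pretrained(name)

@torch.no_grad()
def encode_frames(frames: list[Image.Image]) -> torch.Tensor:
    # returns (num_frames, 1 + 3*3, hidden_dim)
    pixel_values = processor(images=frames, return_tensors='pt')['pixel_values']
    out = vision(pixel_values=pixel_values)
    pooled = out.pooler_output[:, None]     # (N, 1, D): frame-level token
    grid = out.last_hidden_state            # (N, 24*24, D): patch tokens
    n, l, d = grid.shape
    side = int(l ** 0.5)                    # 384 / 16 = 24
    grid = grid.transpose(1, 2).reshape(n, d, side, side)
    grid = F.adaptive_avg_pool2d(grid, 3)   # (N, D, 3, 3)
    grid = grid.flatten(2).transpose(1, 2)  # (N, 9, D)
    return torch.cat([pooled, grid], dim=1)
```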

### Model Zoo

9 changes: 0 additions & 9 deletions data/README.md
@@ -1,9 +0,0 @@

- (Optional) We also recommend using a newer ffmpeg version for better video preprocessing:

```
wget https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz
tar xvf ffmpeg-release-amd64-static.tar.xz
rm ffmpeg-release-amd64-static.tar.xz
mv ffmpeg-6.1-amd64-static ffmpeg
```
282 changes: 148 additions & 134 deletions data/coin/benchmarks.py

Large diffs are not rendered by default.

43 changes: 28 additions & 15 deletions data/coin/coin.py
@@ -1,21 +1,14 @@
import os, json, tqdm, torch
from transformers import CLIPImageProcessor

class COIN:
root = 'datasets/coin'
video_root = os.path.join(root, 'videos')
def __init__(self, split: str, model_config: dict, fps: int, load_vision_embeds: bool, **kwargs):
super().__init__(load_vision_embeds=load_vision_embeds, **kwargs)
# 1. prepare load path
vision_pretrained = model_config.vision_pretrained if not isinstance(model_config, dict) else model_config['vision_pretrained']
frame_processor = CLIPImageProcessor.from_pretrained(vision_pretrained)
crop_size = frame_processor.crop_size
self.frames_dir = f"{self.video_root}_{fps}fps_{crop_size['height']}x{crop_size['width']}"
if load_vision_embeds:
self.frames_dir += '_' + vision_pretrained.replace('/', '--')

# 2. prepare annos for all downstream benchmarks
self.metadata = get_metadata(self.frames_dir, load_vision_embeds, fps)
video_root = os.path.join(root, 'full_scale')
anno_root = os.path.join(root, 'annotations')
def __init__(self, split: str, vision_pretrained: str, embed_mark: str, frame_fps: int, **kwargs):
super().__init__(**kwargs)
self.embed_dir = f"{self.video_root}_{embed_mark}_{vision_pretrained.replace('/', '--')}"
self.frame_fps = frame_fps
self.metadata = self.get_metadata()
annos = json.load(open(os.path.join(self.root, 'coin.json')))['database']
assert split in ['train', 'test']
self._annos = [{
@@ -29,8 +22,28 @@ def __init__(self, split: str, model_config: dict, fps: int, load_vision_embeds:
text=COIN._clean_step(step['label']),
) for step in anno['annotation']],
} for video_uid, anno in annos.items() if (split in anno['subset'].lower()) and (video_uid in self.metadata)]
self.tasks_categories = list(set([v['task'].capitalize() + '.' for v in self._annos]))
self.steps_categories = list(set([step['text'].capitalize() + '.' for steps in self._annos for step in steps['steps']]))
self.annos: list[dict]

def __len__(self):
return len(self.annos)

def get_metadata(self, ):
metadata_path = f'{self.embed_dir}_metadata.json'
if os.path.exists(metadata_path):
print(f'load {metadata_path}...')
metadata = json.load(open(metadata_path))
else:
metadata = {}
for file in tqdm.tqdm(os.listdir(self.embed_dir), desc=f'prepare {metadata_path}...'):
path = os.path.join(self.embed_dir, file)
duration = (len(torch.load(path)) - 1) / self.frame_fps
key = os.path.splitext(os.path.basename(path))[0]
metadata[key] = {'duration': duration, 'path': path}
json.dump(metadata, open(metadata_path, 'w'), indent=4)
return metadata

@staticmethod
def _clean_step(step):
replaces = {
@@ -57,4 +70,4 @@ def _clean_task(text)
return result.strip()

def __len__(self):
return len(self.annos)
return len(self.annos)
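
A note on the duration arithmetic in ```get_metadata``` above: frame 0 sits at t = 0, so N frame embeddings sampled at ```frame_fps``` span (N - 1) / fps seconds. A standalone check:

```python
frame_fps = 2
num_frames = 121                          # e.g. len(torch.load(path)) for one video
duration = (num_frames - 1) / frame_fps   # frame 0 is at t = 0
assert duration == 60.0
```
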
18 changes: 18 additions & 0 deletions data/livechat/filter.py
@@ -0,0 +1,18 @@
import json, re

annos = json.load(open('goalstep_livechat_trainval.json'))

new_annos = []
for anno in annos:
    if not anno['conversation']:
        continue
    maintain = True
    anno['duration'] = anno['conversation'][-1]['time'] - anno['conversation'][0]['time']
    if anno['duration'] < 60 or anno['duration'] > 3600:
        continue
    for message in anno['conversation']:
        if 'second' in message['content'] or re.search(r'\b\d+s\b', message['content']): # if the generated content contains time-related text, it may leak the future ground truth
            maintain = False
            break
    if maintain:
        new_annos.append(anno)
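
To see what the leakage heuristic above catches, a standalone check (mirroring the filter's search over each message):

```python
import re

def leaks_time(content: str) -> bool:
    # literal 'second' or a bare '<digits>s' token, as in the filter above
    return 'second' in content or bool(re.search(r'\b\d+s\b', content))

assert leaks_time('You will finish whisking in 30s')
assert leaks_time('After a few seconds, flip the pancake')
assert not leaks_time('Now you are chopping the onion')
```
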
39 changes: 39 additions & 0 deletions data/preprocess/README.md
@@ -0,0 +1,39 @@
### Distributed Preprocessing of Video Frames for VideoLLM-online

#### Sample video frames to 2 FPS and max resolution 384 (with zero padding)

```
python -m data.preprocess.ffmpeg --num_gpus 8 --frame_fps 2 --frame_resolution 384 --video_dir datasets/ego4d/v2/full_scale
```

- Please run the script from the ```videollm-online/``` root folder.

- The results will be saved in a new folder with a '{fps}fps_{resolution}' suffix. For example, ```datasets/ego4d/v2/full_scale -> datasets/ego4d/v2/full_scale_2fps_384```.

- If you are on a Slurm cluster, you can set ```--num_nodes ... --slurm_partition ...``` to use multiple nodes. The more nodes and GPUs, the faster the preprocessing.
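
For a single video, the sampling above corresponds to an ffmpeg filter chain along these lines (a hedged sketch; the repository's ```ffmpeg_once``` helper in ```data/utils.py``` may use different flags):

```python
import subprocess

def sample_once(src: str, dst: str, fps: int = 2, resolution: int = 384) -> None:
    # fit the long side to `resolution` (keeping aspect ratio), then zero-pad to a square
    vf = (
        f"fps={fps},"
        f"scale='if(gt(iw,ih),{resolution},-2)':'if(gt(iw,ih),-2,{resolution})',"
        f"pad={resolution}:{resolution}:(ow-iw)/2:(oh-ih)/2"
    )
    subprocess.run(['ffmpeg', '-y', '-i', src, '-vf', vf, dst], check=True)
```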

#### Encode sampled 2fps_384 video frames

```
python -m data.preprocess.encode --num_gpus 8 --video_dir datasets/ego4d/v2/full_scale_2fps_384 --vision_pretrained google/siglip-large-patch16-384
```

- Please run the script from the ```videollm-online/``` root folder.

- The results will be saved in a new folder with a '{embed_mark}_{model}' suffix. For example, ```datasets/ego4d/v2/full_scale_2fps_384 -> datasets/ego4d/v2/full_scale_2fps_384_1+3x3_google--siglip-large-patch16-384```.

- If you are on a Slurm cluster, you can set ```--num_nodes ... --slurm_partition ...``` to use multiple nodes. The more nodes and GPUs, the faster the preprocessing.
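
The saved features can be inspected directly. Under the '1+3x3' mark, each video should map to one tensor of shape (num_frames, 1 + 3*3, hidden_dim), saved as bfloat16 when ```save_bf16``` is set; a quick sanity check (assuming the example directory name above):

```python
import os, torch

embed_dir = 'datasets/ego4d/v2/full_scale_2fps_384_1+3x3_google--siglip-large-patch16-384'
path = os.path.join(embed_dir, os.listdir(embed_dir)[0])  # any one video's features
feats = torch.load(path)
print(feats.shape, feats.dtype)  # e.g. torch.Size([N, 10, 1024]), torch.bfloat16
```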

#### Narration Refinement

```
python -m data.preprocess.ego4d_narration_refinement --llm_pretrained meta-llama/Meta-Llama-3-8B-Instruct --anno_root datasets/ego4d/v2/annotations --split train
python -m data.preprocess.ego4d_narration_refinement --llm_pretrained meta-llama/Meta-Llama-3-8B-Instruct --anno_root datasets/ego4d/v2/annotations --split val
```

- Please run the script from the ```videollm-online/``` root folder.

- The results will be saved in a new JSON named 'refined_narration_stream_{args.split}'. For example, ```datasets/ego4d/v2/annotations/narration_stream_train.json -> datasets/ego4d/v2/annotations/refined_narration_stream_train.json```.

- If you are on a Slurm cluster, you can set ```--num_nodes ... --slurm_partition ...``` to use multiple nodes. The more nodes and GPUs, the faster the preprocessing.
72 changes: 72 additions & 0 deletions data/preprocess/ego4d_narration_refinement.py
@@ -0,0 +1,72 @@
import json, torch, tqdm, os, functools, submitit
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from dataclasses import dataclass

from models.arguments_live import LiveOnePlusTrainingArguments

@dataclass
class LiveOnePlusEncodingArguments(LiveOnePlusTrainingArguments):
    num_nodes: int = 1
    num_gpus: int = 8
    anno_root: str = 'datasets/ego4d/v2/annotations'
    split: str = 'train'

@torch.no_grad()
def distributed_refine_narration(args: LiveOnePlusEncodingArguments):
    env = submitit.JobEnvironment()
    torch.cuda.set_device(env.local_rank)

    model = AutoModelForCausalLM.from_pretrained(args.llm_pretrained, torch_dtype='auto', attn_implementation='sdpa')
    tokenizer = AutoTokenizer.from_pretrained(args.llm_pretrained, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    model.to('cuda')
    generator = functools.partial(model.generate, max_new_tokens=64, do_sample=False, top_p=1.0, temperature=1.0, use_cache=True, pad_token_id=tokenizer.pad_token_id)

    anno_path = os.path.join(args.ego4d_anno_root, f'narration_stream_{args.split}.json')
    save_dir = os.path.join(args.ego4d_anno_root, f'refined_narration_stream_{args.split}')
    annos = json.load(open(anno_path))
    os.makedirs(save_dir, exist_ok=True)
    mapping = {}

    annos = {video_uid: _annotation_uid_narrations for i, (video_uid, _annotation_uid_narrations) in tqdm.tqdm(enumerate(annos.items())) if not os.path.exists(os.path.join(save_dir, f'{video_uid}.json'))}
    for i, (video_uid, _annotation_uid_narrations) in tqdm.tqdm(enumerate(annos.items())):
        if i % env.num_tasks != env.global_rank:
            continue
        save_path = os.path.join(save_dir, f'{video_uid}.json')
        for _annotation_uid, narrations in _annotation_uid_narrations.items():
            for narration in narrations:
                if narration['text'] not in mapping:
                    chat = [
                        {
                            "role": "user", "content": ("Please help me to refine the text, e.g., [C looks around.] -> [You look around.] "
                            "In the text, there are many uppercase letters to denote persons. Rewrite the sentence to avoid these uppercase letters, improve the text quality, make the text clear and concise. "
                            "For example:\n[C looks around.] -> [You look around.]\n[A man X watches the phone.] -> [A man watches the phone.]\n"
                            f"[C plays a piano, and a woman O comes to him.] -> [You play a piano, and a woman comes to you.]\n[Man A approaches C] -> [A man approaches you.]\n\nNow, please refine [{narration['text']}] -> ?, make the answer in [].")
                        },
                        {"role": "assistant", "content": f"[{narration['text']}] -> ["}
                    ]
                    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors='pt')[:,:-1].cuda()
                    output_ids = generator(input_ids)[:, input_ids.size(1):]
                    text = tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    try:
                        mapping[narration['text']] = text[:text.index(']')]
                    except ValueError: # no closing bracket in the generation
                        print('failed to refine:', narration['text'], '->', text)
                        mapping[narration['text']] = 'Not sure what you are doing.'
                narration['text'] = mapping[narration['text']]

        json.dump(_annotation_uid_narrations, open(save_path, 'w'), indent=4)

if __name__ == "__main__":
    args, = HfArgumentParser(LiveOnePlusEncodingArguments).parse_args_into_dataclasses()
    executor = submitit.AutoExecutor(folder="outputs/preprocess/")
    executor.update_parameters(
        tasks_per_node=args.num_gpus,
        nodes=args.num_nodes,
        gpus_per_node=args.num_gpus,
        cpus_per_task=10,
        mem_gb=240,
        slurm_time='24:00:00',
    )
    job = executor.submit(distributed_refine_narration, args)
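
The script relies on a pre-fill trick: the assistant turn is seeded with ```[original] -> [``` (with the trailing end-of-turn token sliced off via ```[:,:-1]```), so the model only completes the bracketed rewrite, which is then cut at the first ```]```. The parse step in isolation:

```python
# generation continues the open bracket, so the refined text is
# everything before the first closing bracket
generated = 'You look around.] some trailing tokens'
refined = generated[:generated.index(']')]
assert refined == 'You look around.'
```
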
37 changes: 37 additions & 0 deletions data/preprocess/encode.py
@@ -0,0 +1,37 @@
import submitit, functools, transformers
from dataclasses import asdict, dataclass
from models.vision_live import build_live_vision

from models.configuration_live import LiveConfigMixin
from models.arguments_live import LiveOnePlusTrainingArguments
from ..utils import distributed_encode

@dataclass
class LiveOnePlusEncodingArguments(LiveOnePlusTrainingArguments):
    num_nodes: int = 1
    num_gpus: int = 8
    video_dir: str = 'datasets/ego4d/v2/full_scale_2fps_384'
    slurm_partition: str = None

if __name__ == "__main__":
    args, = transformers.HfArgumentParser(LiveOnePlusEncodingArguments).parse_args_into_dataclasses()
    vision_config = LiveConfigMixin(**asdict(args))
    _, vision_encode = build_live_vision(vision_config)
    task = functools.partial(
        distributed_encode, src_root=args.video_dir,
        vision_pretrained=args.vision_pretrained,
        embed_mark=args.embed_mark,
        vision_encode=vision_encode,
        batch_size=256, save_bf16=True
    )
    executor = submitit.AutoExecutor(folder="outputs/preprocess/")
    executor.update_parameters(
        tasks_per_node=args.num_gpus,
        nodes=args.num_nodes,
        gpus_per_node=args.num_gpus,
        cpus_per_task=10,
        slurm_partition=args.slurm_partition,
        mem_gb=240,
        slurm_time='24:00:00',
    )
    job = executor.submit(task)
28 changes: 28 additions & 0 deletions data/preprocess/ffmpeg.py
@@ -0,0 +1,28 @@
from functools import partial
import submitit, transformers
from dataclasses import dataclass

from models.arguments_live import LiveOnePlusTrainingArguments
from ..utils import distributed_ffmpeg

@dataclass
class LiveOnePlusEncodingArguments(LiveOnePlusTrainingArguments):
    num_nodes: int = 1
    num_gpus: int = 8
    video_dir: str = 'datasets/ego4d/v2/full_scale'
    slurm_partition: str = None

if __name__ == "__main__":
    args, = transformers.HfArgumentParser(LiveOnePlusEncodingArguments).parse_args_into_dataclasses()
    executor = submitit.AutoExecutor(folder="outputs/preprocess/")
    task = partial(distributed_ffmpeg, src_root=args.video_dir, resolution=args.frame_resolution, fps=args.frame_fps)
    executor.update_parameters(
        tasks_per_node=args.num_gpus,
        nodes=args.num_nodes,
        gpus_per_node=args.num_gpus,
        slurm_partition=args.slurm_partition,
        cpus_per_task=10,
        mem_gb=240,
        slurm_time='24:00:00',
    )
    job = executor.submit(task)
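
All three launchers share the same submitit pattern. For debugging without a cluster, submitit's local executor runs a callable on the current machine; a minimal sketch with a toy task (the folder name is hypothetical):

```python
import submitit

def hello(name: str) -> str:
    return f'hello, {name}'

executor = submitit.LocalExecutor(folder='outputs/preprocess_local')
job = executor.submit(hello, 'world')
print(job.result())  # blocks until the local job finishes
```
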
8 changes: 4 additions & 4 deletions data/utils.py
@@ -1,4 +1,4 @@
import random, torch, tqdm, os, subprocess, torchvision, pathlib, json, math
import random, torch, tqdm, os, subprocess, torchvision, pathlib, submitit, math
from itertools import takewhile
try:
torchvision.set_video_backend('video_reader')
@@ -76,19 +76,19 @@ def distributed_ffmpeg(*, src_root: str, fps: int = None, resolution: int = None
dst_root += f'_{fps}fps'
if resolution is not None:
assert (pad is not None)
dst_root += f'_max{resolution}_pad{pad}'
dst_root += f'_max{resolution}'
for i, src_path in tqdm.tqdm(enumerate(src_paths), desc=f'{src_root} -> {dst_root}'):
if i % env.num_tasks != env.global_rank:
continue
dst_path = src_path.replace(src_root, dst_root)
ffmpeg_once(src_path, dst_path, fps=fps, resolution=resolution, pad=pad)

def distributed_encode(*, src_root: str, vision_pretrained: str, vision_encode: callable, batch_size: int, tokens: str, save_bf16: bool = False, **kwargs):
def distributed_encode(*, src_root: str, vision_pretrained: str, vision_encode: callable, batch_size: int, embed_mark: str, save_bf16: bool = False, **kwargs):
env = submitit.JobEnvironment()
src_root = src_root.rstrip('/')
model = AutoModel.from_pretrained(vision_pretrained, device_map=f'cuda:{env.local_rank}').vision_model
model.eval()
dst_root = f"{src_root}_{tokens}_{vision_pretrained.replace('/', '--')}"
dst_root = f"{src_root}_{embed_mark}_{vision_pretrained.replace('/', '--')}"
os.makedirs(dst_root, exist_ok=True)
for i, file in tqdm.tqdm(enumerate(os.listdir(src_root)), desc=f'{src_root} -> {dst_root}'):
if i % env.num_tasks != env.global_rank:
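
Both ```distributed_ffmpeg``` and ```distributed_encode``` shard work with the same round-robin rule seen above: the task with rank ```global_rank``` out of ```num_tasks``` processes item i iff i % num_tasks == global_rank. In isolation:

```python
items = [f'video_{i}.mp4' for i in range(10)]
num_tasks, global_rank = 4, 1  # e.g. one of 8 GPU tasks per node in the scripts above
mine = [x for i, x in enumerate(items) if i % num_tasks == global_rank]
assert mine == ['video_1.mp4', 'video_5.mp4', 'video_9.mp4']
```
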
2 changes: 1 addition & 1 deletion models/__init__.py
@@ -7,4 +7,4 @@
def parse_args() -> LiveTrainingArguments:
args, = HfArgumentParser(LiveTrainingArguments).parse_args_into_dataclasses()
args, = HfArgumentParser(get_args_class(args.live_version)).parse_args_into_dataclasses()
return args
return args