predict_p3.py
import os
import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from modules.utils import set_random_seed
from modules.dataset import FullLengthVideosDataset
from modules.models.LRCNN import SeqRecurrentCNN
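
# Example invocation (paths are illustrative, not from the repo):
#   python predict_p3.py data/videos/ checkpoints/model.pth outputs/ --batch_size 4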
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Predict full-length videos.")
parser.add_argument("video_path", type=str, help="Path to video directory.")
parser.add_argument("model_dir", type=str, help="Where to load trained model.")
parser.add_argument("output_dir", type=str, help="Output directory.")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size.")
parser.add_argument(
"--max_padding", type=int, default=30, help="Max padding length of frames."
)
parser.add_argument("--ds_factor", type=int, default=12, help="Down-sample factor.")
parser.add_argument("--rescale_factor", type=int, default=1, help="Rescale factor.")
parser.add_argument(
"--sorting",
action="store_true",
help="Whether to sort by video length per batch.",
)
parser.add_argument(
"--n_workers", type=int, default=8, help="Number of worker for dataloader."
)
parser.add_argument("--random_seed", type=int, default=42, help="Random seed.")
args = parser.parse_args()
set_random_seed(args.random_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # prepare dataset: split each video into fixed-length segments
    # (100 frames, no overlap); no label directory in test mode
    dataset = FullLengthVideosDataset(
        args.video_path, None, length=100, overlap=0, sorting=True, test=True
    )
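    # the dataset's own collate_fn handles batching of variable-length segments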
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.n_workers,
shuffle=False,
collate_fn=dataset.collate_fn,
)
    # prepare model: load the checkpoint onto whichever device is available
    model = SeqRecurrentCNN()
    model.load_state_dict(torch.load(args.model_dir, map_location=device))
    model.to(device)
    model.eval()
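
    # accumulate per-frame label predictions, keyed by video category name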
preds_ls = {}
with torch.no_grad():
        for batch in tqdm(dataloader):
frames = batch["frames"].to(device)
frames_len = batch["frames_len"]
sorted_idx = batch["sorted_idx"]
category = batch["category"]
            preds = model(frames, frames_len)
            # predicted label = argmax over the class dimension (dim=2);
            # exp() of the log-probability outputs is monotonic, so it is
            # not needed to recover the argmax
            preds = preds.max(dim=2)[1].cpu().numpy()
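            # sorted_idx maps batch rows back to their original order;
            # keep only the valid (unpadded) frames of each segment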
for _idx in sorted_idx:
cat = category[_idx][0]
if cat not in preds_ls:
preds_ls[cat] = []
preds_ls[cat] += preds[_idx][: frames_len[_idx]].tolist()
    # write one prediction file per video, one label per line
    os.makedirs(args.output_dir, exist_ok=True)
    for k, v in preds_ls.items():
        output_path = os.path.join(args.output_dir, "{}.txt".format(k))
        with open(output_path, "w") as fout:
            for label in v:
                fout.write(str(label) + "\n")