recognizer2d.py
import torch
from torch import nn

from ..builder import RECOGNIZERS
from .base import BaseRecognizer


@RECOGNIZERS.register_module()
class Recognizer2D(BaseRecognizer):
    """2D recognizer model framework."""

    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training."""

        assert self.with_cls_head

        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches
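        # The segments are now folded into the batch dimension,
        # [N, num_segs, C, H, W] -> [N * num_segs, C, H, W], so the 2D
        # backbone processes every frame independently.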

        losses = dict()

        x = self.extract_feat(imgs)

        if self.backbone_from in ['torchvision', 'timm']:
            if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                # apply adaptive avg pooling
                x = nn.AdaptiveAvgPool2d(1)(x)
            x = x.reshape((x.shape[0], -1))
            x = x.reshape(x.shape + (1, 1))
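            # The reshapes above normalize raw torchvision/timm backbone
            # output to [N * num_segs, C, 1, 1], the layout the recognition
            # heads expect.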

        if self.with_neck:
            x = [
                each.reshape((-1, num_segs) +
                             each.shape[1:]).transpose(1, 2).contiguous()
                for each in x
            ]
            x, loss_aux = self.neck(x, labels.squeeze())
            x = x.squeeze(2)
            num_segs = 1
            losses.update(loss_aux)
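            # The neck (e.g. a TPN) consumes [N, C, num_segs, H, W] features
            # and fuses the temporal dimension, hence num_segs is reset to 1
            # before the head is called.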

        cls_score = self.cls_head(x, num_segs)
        gt_labels = labels.squeeze()
        loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
        losses.update(loss_cls)

        return losses

    def _do_test(self, imgs):
        """Defines the computation performed at every call when evaluation,
        testing and gradcam."""
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches

        x = self.extract_feat(imgs)

        if self.backbone_from in ['torchvision', 'timm']:
            if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                # apply adaptive avg pooling
                x = nn.AdaptiveAvgPool2d(1)(x)
            x = x.reshape((x.shape[0], -1))
            x = x.reshape(x.shape + (1, 1))

        if self.with_neck:
            x = [
                each.reshape((-1, num_segs) +
                             each.shape[1:]).transpose(1, 2).contiguous()
                for each in x
            ]
            x, _ = self.neck(x)
            x = x.squeeze(2)
            num_segs = 1

        if self.feature_extraction:
            # perform spatial pooling
            avg_pool = nn.AdaptiveAvgPool2d(1)
            x = avg_pool(x)
            # squeeze dimensions
            x = x.reshape((batches, num_segs, -1))
            # temporal average pooling
            x = x.mean(dim=1)
            return x

        # When using `TSNHead` or `TPNHead`, shape is [batch_size, num_classes]
        # When using `TSMHead`, shape is [batch_size * num_crops, num_classes]
        # `num_crops` is calculated by:
        #   1) `twice_sample` in `SampleFrames`
        #   2) `num_sample_positions` in `DenseSampleFrames`
        #   3) `ThreeCrop/TenCrop/MultiGroupCrop` in `test_pipeline`
        #   4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1`
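        # Worked example (illustrative): with batches=2, `ThreeCrop` in the
        # test pipeline and a `TSMHead`, cls_score has 2 * 3 = 6 rows, so
        # `average_clip` below receives num_crops = 6 // 2 = 3 and folds the
        # scores back to [2, num_classes].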
        # should have cls_head if not extracting features
        cls_score = self.cls_head(x, num_segs)
        assert cls_score.size()[0] % batches == 0
        # calculate num_crops automatically
        cls_score = self.average_clip(cls_score,
                                      cls_score.size()[0] // batches)
        return cls_score

    def _do_fcn_test(self, imgs):
        # [N, num_crops * num_segs, C, H, W] ->
        # [N * num_crops * num_segs, C, H, W]
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = self.test_cfg.get('num_segs', self.backbone.num_segments)
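        # The optional horizontal flip below acts as test-time augmentation.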
        if self.test_cfg.get('flip', False):
            imgs = torch.flip(imgs, [-1])
        x = self.extract_feat(imgs)

        if self.with_neck:
            x = [
                each.reshape((-1, num_segs) +
                             each.shape[1:]).transpose(1, 2).contiguous()
                for each in x
            ]
            x, _ = self.neck(x)
        else:
            x = x.reshape((-1, num_segs) +
                          x.shape[1:]).transpose(1, 2).contiguous()
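        # In the plain-backbone branch the features become
        # [N * num_crops, C, num_segs, H, W]; spatial dims are kept so the
        # head can run fully convolutionally.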

        # When using `TSNHead` or `TPNHead`, shape is [batch_size, num_classes]
        # When using `TSMHead`, shape is [batch_size * num_crops, num_classes]
        # `num_crops` is calculated by:
        #   1) `twice_sample` in `SampleFrames`
        #   2) `num_sample_positions` in `DenseSampleFrames`
        #   3) `ThreeCrop/TenCrop/MultiGroupCrop` in `test_pipeline`
        #   4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1`
        cls_score = self.cls_head(x, fcn_test=True)
        assert cls_score.size()[0] % batches == 0
        # calculate num_crops automatically
        cls_score = self.average_clip(cls_score,
                                      cls_score.size()[0] // batches)
        return cls_score

    def forward_test(self, imgs):
        """Defines the computation performed at every call when evaluation and
        testing."""
        if self.test_cfg.get('fcn_test', False):
            # If specified, spatially fully-convolutional testing is performed
            assert not self.feature_extraction
            assert self.with_cls_head
            return self._do_fcn_test(imgs).cpu().numpy()
        return self._do_test(imgs).cpu().numpy()

    def forward_dummy(self, imgs, softmax=False):
        """Used for computing network FLOPs.

        See ``tools/analysis/get_flops.py``.

        Args:
            imgs (torch.Tensor): Input images.
            softmax (bool): Whether to apply softmax to the class scores.
                Defaults to False.

        Returns:
            tuple[torch.Tensor]: Class score.
        """
        assert self.with_cls_head
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches

        x = self.extract_feat(imgs)
        if self.with_neck:
            x = [
                each.reshape((-1, num_segs) +
                             each.shape[1:]).transpose(1, 2).contiguous()
                for each in x
            ]
            x, _ = self.neck(x)
            x = x.squeeze(2)
            num_segs = 1

        outs = self.cls_head(x, num_segs)
        if softmax:
            # pass `dim` explicitly; calling softmax without it is
            # deprecated in PyTorch
            outs = nn.functional.softmax(outs, dim=-1)
        return (outs, )

    def forward_gradcam(self, imgs):
        """Defines the computation performed at every call when using gradcam
        utils."""
        assert self.with_cls_head
        return self._do_test(imgs)
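

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the segment-folding
# and temporal-averaging pattern used throughout `Recognizer2D`, shown in
# isolation with a toy backbone. All names and shapes below are hypothetical;
# copy this into a scratch script with `import torch` / `from torch import nn`
# to run it, since this module's relative imports prevent direct execution.
if __name__ == '__main__':
    n, num_segs, c, h, w = 2, 8, 3, 32, 32
    imgs = torch.randn(n, num_segs, c, h, w)

    # Fold segments into the batch dim, as `forward_train` does.
    imgs = imgs.reshape((-1, ) + imgs.shape[2:])  # [n * num_segs, c, h, w]

    toy_backbone = nn.Sequential(
        nn.Conv2d(c, 16, 3, padding=1),  # stand-in for `extract_feat`
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten())
    feats = toy_backbone(imgs)  # [n * num_segs, 16]

    # Unfold and average over segments, as the feature-extraction branch of
    # `_do_test` does.
    feats = feats.reshape(n, num_segs, -1).mean(dim=1)  # [n, 16]
    print(feats.shape)  # torch.Size([2, 16])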