You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
修改 DBNet++（dbnetpp）模型，希望将 neck 中的 FPN 换成 PFPN，并在 PFPN 代码中加入如下代码：
for i, out in enumerate(outs):
outs[i] = F.interpolate(
outs[i], size=outs[0].shape[2:], mode='nearest') # list
out = torch.cat(outs, dim=1)
print(out.shape) # torch.Size([8, 1280, 160, 160])
if self.asf_cfg is not None:
asf_feature = self.asf_conv(out)
attention = self.asf_attn(asf_feature)
enhanced_feature = []
for i, out in enumerate(out):
enhanced_feature.append(attention[:, i:i + 1] * out[i])
out = torch.cat(enhanced_feature, dim=1)
print(out.shape) # torch.Size([8, 5, 160, 160])
if self.conv_after_concat:
out = self.out_conv(out)
print(out.shape) # torch.Size([8, 5, 160, 160])
return tuple(out)
运行后报错如下：
Traceback (most recent call last):
File "tools/train.py", line 231, in <module>
main()
File "tools/train.py", line 220, in main
train_detector(
File "/root/mmocr/mmocr/apis/train.py", line 155, in train_detector
runner.run(data_loaders, cfg.workflow)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 136, in run
epoch_runner(data_loaders[i], **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 53, in train
self.run_iter(data_batch, train_mode=True, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 31, in run_iter
outputs = self.model.train_step(data_batch, self.optimizer,
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 77, in train_step
return self.module.train_step(*inputs[0], **kwargs[0])
File "/root/miniconda3/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 248, in train_step
losses = self(**data)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 116, in new_func
return old_func(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 172, in forward
return self.forward_train(img, img_metas, **kwargs)
File "/root/mmocr/mmocr/models/textdet/detectors/single_stage_text_detector.py", line 37, in forward_train
preds = self.bbox_head(x[0])
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/mmocr/mmocr/models/textdet/dense_heads/db_head.py", line 80, in forward
prob_map = self.binarize(inputs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 256, 3, 3], but got 3-dimensional input of size [5, 160, 160] instead
请问该如何解决这个问题？另外，如果去掉上面新增的代码（即不加入特征融合模块），模型可以正常构建并训练。希望能得到一些建议，感谢回复！
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
修改 DBNet++（dbnetpp）模型，希望将 neck 中的 FPN 换成 PFPN，并在 PFPN 代码中加入如下代码：
for i, out in enumerate(outs):
outs[i] = F.interpolate(
outs[i], size=outs[0].shape[2:], mode='nearest') # list
out = torch.cat(outs, dim=1)
print(out.shape) # torch.Size([8, 1280, 160, 160])
if self.asf_cfg is not None:
asf_feature = self.asf_conv(out)
attention = self.asf_attn(asf_feature)
enhanced_feature = []
for i, out in enumerate(out):
enhanced_feature.append(attention[:, i:i + 1] * out[i])
out = torch.cat(enhanced_feature, dim=1)
print(out.shape) # torch.Size([8, 5, 160, 160])
if self.conv_after_concat:
out = self.out_conv(out)
print(out.shape) # torch.Size([8, 5, 160, 160])
return tuple(out)
运行后报错如下：
Traceback (most recent call last):
File "tools/train.py", line 231, in <module>
main()
File "tools/train.py", line 220, in main
train_detector(
File "/root/mmocr/mmocr/apis/train.py", line 155, in train_detector
runner.run(data_loaders, cfg.workflow)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 136, in run
epoch_runner(data_loaders[i], **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 53, in train
self.run_iter(data_batch, train_mode=True, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 31, in run_iter
outputs = self.model.train_step(data_batch, self.optimizer,
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 77, in train_step
return self.module.train_step(*inputs[0], **kwargs[0])
File "/root/miniconda3/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 248, in train_step
losses = self(**data)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 116, in new_func
return old_func(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 172, in forward
return self.forward_train(img, img_metas, **kwargs)
File "/root/mmocr/mmocr/models/textdet/detectors/single_stage_text_detector.py", line 37, in forward_train
preds = self.bbox_head(x[0])
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/mmocr/mmocr/models/textdet/dense_heads/db_head.py", line 80, in forward
prob_map = self.binarize(inputs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 256, 3, 3], but got 3-dimensional input of size [5, 160, 160] instead
请问该如何解决这个问题？另外，如果去掉上面新增的代码（即不加入特征融合模块），模型可以正常构建并训练。希望能得到一些建议，感谢回复！
Beta Was this translation helpful? Give feedback.
All reactions