You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I wanna convert pth to onnx format. This is my code:
import torch
from model.faster_rcnn.vgg16 import vgg16
from model.faster_rcnn.resnet import resnet
import numpy as np
from torch.autograd import Variable
def load_model(model, pretrained_path):
print('Loading pretrained model from {}'.format(pretrained_path))
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
model.load_state_dict(pretrained_dict, strict=False)
return model
net = resnet(pascal_classes, 101, pretrained=False, class_agnostic=False)
net.create_architecture()
checkpoint = torch.load(raw_weights)
for k in checkpoint.keys():
print(k)
net.load_state_dict(checkpoint['model'])
but when I run it, I got this error. How can I fix it?Thanks!
Traceback (most recent call last):
File "pth2onnx.py", line 67, in
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/init.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 701, in _export
dynamic_axes=dynamic_axes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 459, in _model_to_graph
use_new_jit_passes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 130, in forward
self._force_outplace,
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: forward() missing 3 required positional arguments: 'im_info', 'gt_boxes', and 'num_boxes'
The text was updated successfully, but these errors were encountered:
I wanna convert pth to onnx format. This is my code:
import torch
from model.faster_rcnn.vgg16 import vgg16
from model.faster_rcnn.resnet import resnet
import numpy as np
from torch.autograd import Variable
def load_model(model, pretrained_path):
print('Loading pretrained model from {}'.format(pretrained_path))
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
model.load_state_dict(pretrained_dict, strict=False)
return model
output_onnx = './output.onnx'
raw_weights = './faster_rcnn_1_10_2504.pth'
pascal_classes = np.asarray(['background',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor'])
load weight
net = resnet(pascal_classes, 101, pretrained=False, class_agnostic=False)
net.create_architecture()
checkpoint = torch.load(raw_weights)
for k in checkpoint.keys():
print(k)
net.load_state_dict(checkpoint['model'])
initilize the tensor holder here.
im_data = torch.FloatTensor(1)
im_info = torch.FloatTensor(1)
num_boxes = torch.LongTensor(1)
gt_boxes = torch.FloatTensor(1)
ship to cuda
im_data = im_data.cuda()
im_info = im_info.cuda()
num_boxes = num_boxes.cuda()
gt_boxes = gt_boxes.cuda()
make variable
im_data = Variable(im_data, volatile=True)
im_info = Variable(im_info, volatile=True)
num_boxes = Variable(num_boxes, volatile=True)
gt_boxes = Variable(gt_boxes, volatile=True)
net.eval()
print('Finished loading model!')
device = torch.device("cuda")
net = net.to(device)
input_names = ["input0"]
output_names = ["output0"]
inputs = torch.randn(1, 3, 300, 300).to(device)
output model
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
but when I run it, I got this error. How can I fix it?Thanks!
Traceback (most recent call last):
File "pth2onnx.py", line 67, in
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/init.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 701, in _export
dynamic_axes=dynamic_axes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 459, in _model_to_graph
use_new_jit_passes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 130, in forward
self._force_outplace,
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: forward() missing 3 required positional arguments: 'im_info', 'gt_boxes', and 'num_boxes'
The text was updated successfully, but these errors were encountered: