ort_1.py
#!/usr/bin/env python
# coding=utf-8
import onnxruntime as ort
import numpy as np
from pred import build_predict_text, key


def get_ort_session(model_path, providers=None):
    """Create one InferenceSession per execution provider (all available providers by default)."""
    if providers is None:
        providers = ort.get_available_providers()
    return [ort.InferenceSession(model_path, providers=[provider]) for provider in providers]


def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array, detaching it first if it tracks gradients."""
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def predict(sess, text):
    """Run the ONNX session on one piece of text and return the predicted label."""
    ids, seq_len, mask = build_predict_text(text)
    # Only ids and mask are fed to the session; seq_len is unused here.
    inputs = {
        sess.get_inputs()[0].name: to_numpy(ids),
        sess.get_inputs()[1].name: to_numpy(mask),
    }
    outs = sess.run(None, inputs)
    num = np.argmax(outs)
    return key[num]


if __name__ == '__main__':
    sess = ort.InferenceSession("./model.onnx", providers=['CUDAExecutionProvider'])
    # Sample Chinese news headlines used as test inputs.
    ts = [
        "李稻葵:过去2年抗疫为每人增寿10天",
        "4个小学生离家出走30公里想去广州塔",
        "朱一龙戏路打通电影电视剧",
        "天问一号着陆火星一周年",
    ]
    for t in ts:
        res = predict(sess, t)
        print("%s is %s" % (t, res))
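
    # Sketch (not part of the original flow): the get_ort_session helper defined above
    # can be used instead of hard-coding CUDAExecutionProvider. It builds one session per
    # provider that ONNX Runtime reports as available, so sessions[0] uses the
    # highest-priority provider on this machine.
    sessions = get_ort_session("./model.onnx")
    print("available providers:", ort.get_available_providers())
    print("%s is %s" % (ts[0], predict(sessions[0], ts[0])))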