-
Notifications
You must be signed in to change notification settings - Fork 23
/
test.py
executable file
·117 lines (92 loc) · 4.23 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
# -*- coding: utf-8 -*-
import os
import time
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import cv2
import numpy as np
from utils import file_utils, craft_utils, imgproc
from net.craft import CRAFT
from eval import copyStateDict
def str2bool(v):
    """Interpret a command-line string as a boolean.

    The text is lower-cased and compared against the accepted "true"
    spellings ("yes", "y", "true", "t", "1"); any other text yields False.
    """
    accepted_true = ("yes", "y", "true", "t", "1")
    lowered = v.lower()
    return lowered in accepted_true
# Command-line options for running CRAFT text detection on a folder of images.
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='final_net_param.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.5, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
# This script only runs inference, so the help text says so (it used to say "train").
parser.add_argument('--cuda', default=False, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='/home/brooklyn/ICDAR/icdar2013/test_images/', type=str, help='folder path to input images')
parser.add_argument('--result_folder', default='./result/', type=str, help='folder path to save result images')

args = parser.parse_args()

# For test images in a folder: only the first element of the triple returned
# by get_files is used; it is iterated later as the list of image paths.
image_list, _, _ = file_utils.get_files(args.test_folder)

# Path where the test results are saved.
# makedirs also creates missing parent directories and, with exist_ok=True,
# does not fail when the directory is already present.
result_folder = args.result_folder
os.makedirs(result_folder, exist_ok=True)
def test_net(net, image, text_threshold, link_threshold, low_text, cuda):
    """Run the detector on one image and return its boxes and a score heatmap.

    Besides its parameters, this function reads ``args.canvas_size``,
    ``args.mag_ratio`` and ``args.show_time`` from the module-level ``args``.

    Args:
        net: Callable model; ``net(x)`` must return a pair whose first item
            is indexed as ``y[0, :, :, 0]`` (text score) and
            ``y[0, :, :, 1]`` (link score).
        image: Image array handed to ``imgproc.resize_aspect_ratio``.
        text_threshold: Forwarded to ``craft_utils.getDetBoxes``.
        link_threshold: Forwarded to ``craft_utils.getDetBoxes``.
        low_text: Forwarded to ``craft_utils.getDetBoxes``.
        cuda: When truthy, the input tensor is moved to the GPU before the
            forward pass.

    Returns:
        A ``(boxes, ret_score_text)`` tuple: the boxes after
        ``craft_utils.adjustResultCoordinates`` and the heatmap image built
        from the text and link score maps placed side by side.
    """
    t0 = time.time()

    # resize (the third return value is not needed here)
    img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    # Plain tensors are used directly; the deprecated Variable wrapper is not needed.
    x = x.unsqueeze(0)                          # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, _ = net(x)

    # make score and link map (detach() replaces the legacy .data access)
    score_text = y[0, :, :, 0].detach().cpu().numpy()
    score_link = y[0, :, :, 1].detach().cpu().numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text)
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

    t1 = time.time() - t1

    # render results (optional); hstack already allocates a new array, so no
    # explicit copy of score_text is required
    render_img = np.hstack((score_text, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, ret_score_text
if __name__ == '__main__':
    # Instantiate the detector, then restore its weights from the checkpoint.
    net = CRAFT()
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if not args.cuda:
        # CPU run: remap the checkpoint onto the CPU while loading it.
        checkpoint = torch.load(args.trained_model, map_location='cpu')
        net.load_state_dict(copyStateDict(checkpoint))
    else:
        checkpoint = torch.load(args.trained_model)
        net.load_state_dict(copyStateDict(checkpoint))
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False

    net.eval()

    # The timer starts before the two diagnostic prints below.
    t = time.time()
    print("net.eval")
    print(image_list)

    # Run detection on every image and write out its results.
    for position, image_path in enumerate(image_list, start=1):
        progress_line = "Test image {:d}/{:d}: {:s}".format(position, len(image_list), image_path)
        print(progress_line, end='\r')

        image = imgproc.loadImage(image_path)
        bboxes, score_text = test_net(
            net,
            image,
            args.text_threshold,
            args.link_threshold,
            args.low_text,
            args.cuda,
        )

        # Save the score heatmap under the image's base name (extension dropped).
        base_name = os.path.basename(image_path)
        filename = os.path.splitext(base_name)[0]
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        # The image is passed with its last axis reversed.
        reversed_channels = image[:, :, ::-1]
        file_utils.saveResult(image_path, reversed_channels, bboxes, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))