forked from bubbliiiing/pspnet-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pspnet.py
108 lines (94 loc) · 4.57 KB
/
pspnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from nets.pspnet import PSPNet as pspnet
from torch import nn
from PIL import Image
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import colorsys
import torch
import copy
import os
class PSPNet(object):
    """Inference wrapper around the PSPNet segmentation network.

    Loads a trained checkpoint, runs single-image semantic segmentation,
    and renders the predicted class map as a colour image (optionally
    blended with the input).

    Configure via ``_defaults`` or by passing overrides as keyword
    arguments to the constructor (e.g. ``PSPNet(cuda=False)``).
    """

    #-----------------------------------------#
    #   Adjust model_path, num_classes and
    #   backbone to match your own model.
    #-----------------------------------------#
    _defaults = {
        "model_path"        : 'model_data/pspnet_mobilenetv2.pth',
        "model_image_size"  : (473, 473, 3),
        "backbone"          : "mobilenet",
        "downsample_factor" : 16,
        "num_classes"       : 21,
        "cuda"              : True,
        "blend"             : True,
    }

    #---------------------------------------------------#
    #   Initialization
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Build the wrapper and immediately load the network.

        Any keyword argument overrides the matching ``_defaults`` entry.
        """
        self.__dict__.update(self._defaults)
        # BUGFIX: the original accepted **kwargs but never applied them,
        # so caller overrides were silently ignored.
        self.__dict__.update(kwargs)
        self.generate()

    #---------------------------------------------------#
    #   Load weights and build the colour palette
    #---------------------------------------------------#
    def generate(self):
        """Instantiate the network, load the checkpoint, build colours.

        Side effects: sets ``self.net`` (eval mode, optionally wrapped in
        DataParallel and moved to GPU) and ``self.colors`` (one RGB tuple
        per class).
        """
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        self.net = pspnet(num_classes=self.num_classes,
                          downsample_factor=self.downsample_factor,
                          pretrained=False,
                          backbone=self.backbone,
                          aux_branch=False)
        self.net = self.net.eval()

        # BUGFIX: map_location lets a GPU-saved checkpoint load on a
        # CPU-only machine when cuda=False; previously this crashed.
        device = 'cuda' if self.cuda else 'cpu'
        state_dict = torch.load(self.model_path, map_location=device)
        # strict=False tolerates missing/unexpected keys (e.g. aux branch).
        self.net.load_state_dict(state_dict, strict=False)

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        # Assign a distinct colour to each class.
        if self.num_classes <= 21:
            # VOC-style palette. NOTE(review): the last entry (128, 64, 12)
            # differs from the canonical VOC colour (128, 64, 128); kept
            # as-is to preserve existing output.
            self.colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                           (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                           (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), (128, 64, 12)]
        else:
            # BUGFIX: the original referenced self.class_names, which is
            # never defined in this class (AttributeError). Derive evenly
            # spaced hues from num_classes instead.
            hsv_tuples = [(x / self.num_classes, 1., 1.)
                          for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(
                map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                    self.colors))

    def letterbox_image(self, image, size):
        """Resize ``image`` to fit in ``size`` keeping aspect ratio.

        The image is scaled to the largest size that fits, then centred on
        a grey (128, 128, 128) canvas.

        Args:
            image: PIL image to resize.
            size: (width, height) of the target canvas.

        Returns:
            (new_image, nw, nh): the padded PIL image plus the scaled
            content width and height in pixels.
        """
        iw, ih = image.size
        w, h = size
        scale = min(w / iw, h / ih)
        nw = int(iw * scale)
        nh = int(ih * scale)
        image = image.resize((nw, nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128, 128, 128))
        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
        return new_image, nw, nh

    #---------------------------------------------------#
    #   Detect a single image
    #---------------------------------------------------#
    def detect_image(self, image):
        """Run segmentation on one PIL image and return the rendered result.

        The input is letterboxed to ``model_image_size``, passed through
        the network, argmax-decoded to a class map, cropped back to the
        un-padded region, colourized and resized to the original size.
        If ``self.blend`` is true, the colour map is alpha-blended with
        the input image.

        Args:
            image: input PIL image (RGB assumed).

        Returns:
            A PIL image of the same size as the input.
        """
        old_img = copy.deepcopy(image)
        orininal_h = np.array(image).shape[0]
        orininal_w = np.array(image).shape[1]

        # Letterbox to the network's expected (width, height).
        image, nw, nh = self.letterbox_image(image, (self.model_image_size[1], self.model_image_size[0]))
        images = [np.array(image) / 255]
        images = np.transpose(images, (0, 3, 1, 2))  # NHWC -> NCHW

        with torch.no_grad():
            # Variable() is deprecated and a no-op under no_grad(); use a
            # plain tensor.
            images = torch.from_numpy(images).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            pr = self.net(images)[0]
            # Per-pixel class via softmax + argmax over the channel axis.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
            # Crop away the letterbox padding.
            pr = pr[int((self.model_image_size[0] - nh) // 2):int((self.model_image_size[0] - nh) // 2 + nh),
                    int((self.model_image_size[1] - nw) // 2):int((self.model_image_size[1] - nw) // 2 + nw)]

        # Paint each class with its palette colour.
        seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
        for c in range(self.num_classes):
            seg_img[:, :, 0] += ((pr[:, :] == c) * (self.colors[c][0])).astype('uint8')
            seg_img[:, :, 1] += ((pr[:, :] == c) * (self.colors[c][1])).astype('uint8')
            seg_img[:, :, 2] += ((pr[:, :] == c) * (self.colors[c][2])).astype('uint8')

        image = Image.fromarray(np.uint8(seg_img)).resize((orininal_w, orininal_h))
        if self.blend:
            image = Image.blend(old_img, image, 0.7)
        return image