Commit 3409b41 (1 parent: f95d1d6): 25 changed files with 1,498 additions and 0 deletions.
@@ -0,0 +1,2 @@
This directory stores the .txt files that list the image file names.
@@ -0,0 +1 @@
This directory stores the image files used for training.
@@ -0,0 +1 @@
This directory stores the weights produced during training.
@@ -0,0 +1,44 @@
import os
import random

# Paths to the segmentation labels and to the output ImageSets directory
segfilepath = r'./VOCdevkit/VOC2007/SegmentationClass'
saveBasePath = r"./VOCdevkit/VOC2007/ImageSets/Segmentation/"

# Fraction of the data used for train+val, and the train share within it
trainval_percent = 1
train_percent = 0.9

# Collect all .png label files
temp_seg = os.listdir(segfilepath)
total_seg = []
for seg in temp_seg:
    if seg.endswith(".png"):
        total_seg.append(seg)

num = len(total_seg)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)

print("train and val size", tv)
print("train size", tr)
ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')

# Write each image id (file name without extension) to the proper split file
for i in indices:
    name = total_seg[i][:-4] + '\n'
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
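As a quick sanity check (a minimal sketch, not part of this commit), the generated split files can be read back and counted; the paths below assume the same VOCdevkit layout used by the script above.

import os

save_base = "./VOCdevkit/VOC2007/ImageSets/Segmentation/"
for split in ("trainval", "train", "val", "test"):
    path = os.path.join(save_base, split + ".txt")
    with open(path) as f:
        ids = [line.strip() for line in f if line.strip()]
    # Print the split name, how many ids it holds, and a small sample
    print(split, len(ids), ids[:3])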
@@ -0,0 +1,46 @@
from pspnet import PSPNet
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import os


class miou_Pspnet(PSPNet):
    def detect_image(self, image):
        original_h = np.array(image).shape[0]
        original_w = np.array(image).shape[1]

        # Letterbox the input to the network size, then normalize to [0, 1]
        image, nw, nh = self.letterbox_image(image, (self.model_image_size[1], self.model_image_size[0]))
        images = [np.array(image) / 255]
        images = np.transpose(images, (0, 3, 1, 2))

        with torch.no_grad():
            images = torch.from_numpy(images).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            pr = self.net(images)[0]
            # Per-pixel class prediction
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)

        # Crop away the letterbox padding
        pr = pr[int((self.model_image_size[0] - nh) // 2): int((self.model_image_size[0] - nh) // 2 + nh),
                int((self.model_image_size[1] - nw) // 2): int((self.model_image_size[1] - nw) // 2 + nw)]

        # Resize the label map back to the original image size
        image = Image.fromarray(np.uint8(pr)).resize((original_w, original_h), Image.NEAREST)
        return image


pspnet = miou_Pspnet()

image_ids = open(r"VOCdevkit\VOC2007\ImageSets\Segmentation\val.txt", 'r').read().splitlines()

if not os.path.exists("./miou_pr_dir"):
    os.makedirs("./miou_pr_dir")

for image_id in image_ids:
    image_path = "./VOCdevkit/VOC2007/JPEGImages/" + image_id + ".jpg"
    image = Image.open(image_path)
    image = pspnet.detect_image(image)
    image.save("./miou_pr_dir/" + image_id + ".png")
    print(image_id, "done!")
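detect_image relies on a letterbox_image helper inherited from PSPNet, which is not part of this diff. The sketch below illustrates the usual letterbox approach that the padding crop above assumes (resize preserving aspect ratio, then pad onto a gray canvas); the actual implementation and signature in this repository are assumed, not confirmed.

from PIL import Image

def letterbox_image(image, size):
    # Hypothetical helper: size is (width, height) of the network input
    iw, ih = image.size
    w, h = size
    scale = min(w / iw, h / ih)
    nw, nh = int(iw * scale), int(ih * scale)
    # Resize with preserved aspect ratio, then paste centered on a gray canvas
    resized = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(resized, ((w - nw) // 2, (h - nh) // 2))
    return canvas, nw, nh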
@@ -0,0 +1,63 @@
import base64
import json
import os
import os.path as osp

import numpy as np
import PIL.Image
from labelme import utils

if __name__ == '__main__':
    jpgs_path = "datasets/JPEGImages"
    pngs_path = "datasets/SegmentationClass"
    classes = ["_background_", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    # classes = ["_background_","cat","dog"]

    count = os.listdir("./datasets/before/")
    for i in range(0, len(count)):
        path = os.path.join("./datasets/before", count[i])

        if os.path.isfile(path) and path.endswith('json'):
            data = json.load(open(path))

            # Use the embedded image data if present, otherwise read the referenced image file
            if data['imageData']:
                imageData = data['imageData']
            else:
                imagePath = os.path.join(os.path.dirname(path), data['imagePath'])
                with open(imagePath, 'rb') as f:
                    imageData = f.read()
                    imageData = base64.b64encode(imageData).decode('utf-8')

            img = utils.img_b64_to_arr(imageData)
            label_name_to_value = {'_background_': 0}
            for shape in data['shapes']:
                label_name = shape['label']
                if label_name in label_name_to_value:
                    label_value = label_name_to_value[label_name]
                else:
                    label_value = len(label_name_to_value)
                    label_name_to_value[label_name] = label_value

            # label_values must be dense
            label_values, label_names = [], []
            for ln, lv in sorted(label_name_to_value.items(), key=lambda x: x[1]):
                label_values.append(lv)
                label_names.append(ln)
            assert label_values == list(range(len(label_values)))

            lbl = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value)

            # Save the original image as a .jpg
            PIL.Image.fromarray(img).save(osp.join(jpgs_path, count[i].split(".")[0] + '.jpg'))

            # Remap the per-json label indices to the global class indices defined in `classes`
            new = np.zeros([np.shape(img)[0], np.shape(img)[1]])
            for name in label_names:
                index_json = label_names.index(name)
                index_all = classes.index(name)
                new = new + index_all * (np.array(lbl) == index_json)

            utils.lblsave(osp.join(pngs_path, count[i].split(".")[0] + '.png'), new)
            print('Saved ' + count[i].split(".")[0] + '.jpg and ' + count[i].split(".")[0] + '.png')
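A quick way to confirm the conversion, sketched below under the assumption that at least one PNG has been written to datasets/SegmentationClass: open it and list the class indices that actually appear. The file name "example.png" is a placeholder, not a file produced by this commit.

import numpy as np
from PIL import Image

# Hypothetical example file; substitute one produced by the script above
label = np.array(Image.open("datasets/SegmentationClass/example.png"))
print("class indices present:", sorted(np.unique(label)))  # e.g. [0, 15] for background + person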
@@ -0,0 +1,2 @@
This directory stores the files produced after training.
@@ -0,0 +1,71 @@
import numpy as np
from PIL import Image
from os.path import join


# Labels have width W and height H
def fast_hist(a, b, n):
    # a: ground-truth label flattened to shape (H*W,); b: prediction flattened to shape (H*W,)
    k = (a >= 0) & (a < n)
    # np.bincount counts how often each of the n**2 combined (gt, pred) values occurs;
    # reshaped to (n, n), the diagonal holds the correctly classified pixels
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    # Per-class IoU: diagonal / (row sum + column sum - diagonal); returns shape (n,)
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))


def per_class_PA(hist):
    # Per-class pixel accuracy: diagonal / row sum; returns shape (n,)
    return np.diag(hist) / hist.sum(1)


def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes):
    # Compute mIoU over the validation set
    print('Num classes', num_classes)
    hist = np.zeros((num_classes, num_classes))

    # Paths to the ground-truth labels and to the predicted segmentation results
    gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
    pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]

    # Read each (prediction, label) pair
    for ind in range(len(gt_imgs)):
        # Read one predicted segmentation result as a numpy array
        pred = np.array(Image.open(pred_imgs[ind]))
        # Read the corresponding label as a numpy array
        label = np.array(Image.open(gt_imgs[ind]))

        # Skip the image if the prediction and the label differ in size
        if len(label.flatten()) != len(pred.flatten()):
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
                    pred_imgs[ind]))
            continue
        # Accumulate the num_classes x num_classes hist matrix for this image
        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
        # Every 10 images, print the running class-averaged mIoU and mPA
        if ind > 0 and ind % 10 == 0:
            print('{:d} / {:d}: mIou-{:0.2f}; mPA-{:0.2f}'.format(ind, len(gt_imgs),
                                                                  100 * np.mean(per_class_iu(hist)),
                                                                  100 * np.mean(per_class_PA(hist))))
    # Per-class IoU and pixel accuracy over the whole validation set
    mIoUs = per_class_iu(hist)
    mPA = per_class_PA(hist)
    # Print the per-class values
    for ind_class in range(num_classes):
        print('===>' + name_classes[ind_class] + ':\tmIou-' + str(round(mIoUs[ind_class] * 100, 2)) + '; mPA-' + str(round(mPA[ind_class] * 100, 2)))
    # Class-averaged mIoU and mPA over all validation images, ignoring NaN values
    print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(mPA) * 100, 2)))
    return mIoUs


if __name__ == "__main__":
    gt_dir = "./VOCdevkit/VOC2007/SegmentationClass"
    pred_dir = "./miou_pr_dir"
    png_name_list = open(r"VOCdevkit\VOC2007\ImageSets\Segmentation\val.txt", 'r').read().splitlines()

    num_classes = 21
    name_classes = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes)
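To see how these helpers behave, here is a small worked example with three classes; it uses only fast_hist, per_class_iu, and per_class_PA as defined above, with the expected values worked out by hand.

import numpy as np

gt   = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([0, 1, 1, 1, 2, 0])
hist = fast_hist(gt, pred, 3)
# hist:
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]
print(per_class_iu(hist))   # [0.3333..., 0.6666..., 0.5]
print(per_class_PA(hist))   # [0.5, 1.0, 0.5]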
Binary file not shown.
@@ -0,0 +1,157 @@
import math
import os

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

BatchNorm2d = nn.BatchNorm2d


def conv_bn(inp, oup, stride):
    # 3x3 conv + BN + ReLU6
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    # 1x1 conv + BN + ReLU6
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        interverted_residual_setting = [
            # t, c, n, s
            # 473,473,3 -> 237,237,32
            # 237,237,32 -> 237,237,16
            [1, 16, 1, 1],
            # 237,237,16 -> 119,119,24
            [6, 24, 2, 2],
            # 119,119,24 -> 60,60,32
            [6, 32, 3, 2],
            # 60,60,32 -> 30,30,64
            [6, 64, 4, 2],
            # 30,30,64 -> 30,30,96
            [6, 96, 3, 1],
            # 30,30,96 -> 15,15,160
            [6, 160, 3, 2],
            # 15,15,160 -> 15,15,320
            [6, 320, 1, 1],
        ]

        assert input_size % 32 == 0
        # Build the stem layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel

        self.features = [conv_bn(3, input_channel, 2)]

        # Loop over the settings above to build the MobileNetV2 body
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel

        # Final 1x1 convolution of the MobileNetV2 body
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*self.features)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def load_url(url, model_dir='./model_data', map_location=None):
    # Load cached pretrained weights if present, otherwise download them
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if os.path.exists(cached_file):
        return torch.load(cached_file, map_location=map_location)
    else:
        return model_zoo.load_url(url, model_dir=model_dir)


def mobilenetv2(pretrained=False, **kwargs):
    model = MobileNetV2(n_class=1000, **kwargs)
    if pretrained:
        model.load_state_dict(load_url('http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar'), strict=False)
    return model
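As a usage sketch (not part of this commit), the backbone defined above can be instantiated without pretrained weights and probed for its output shapes; pretrained=False avoids depending on the download URL.

import torch

# Forward a dummy batch through the classifier model
model = mobilenetv2(pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])

# The feature extractor alone is what a segmentation head would consume
with torch.no_grad():
    feats = model.features(x)
print(feats.shape)   # torch.Size([1, 1280, 7, 7]) for a 224x224 input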