-
Notifications
You must be signed in to change notification settings - Fork 0
/
getDatas.py
145 lines (119 loc) · 4.29 KB
/
getDatas.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import numpy as np
from PIL import Image
import os
import torchvision as tv
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
# Settings shared by both pipelines: the target image size and the per-channel
# mean / standard deviation handed to Normalize (these are the values commonly
# used with ImageNet-pretrained models).
_INPUT_SIZE = [224, 224]
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training-set preprocessing: resize, random augmentation, tensor conversion,
# then normalisation. Vertical flip, rotation and a 230x230 resize are not
# part of the pipeline.
preprocessTrain = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        # Randomly jitters brightness, contrast and hue; saturation is left
        # untouched.
        # NOTE(review): the original comment mentioned saturation rather than
        # hue -- confirm which of the two was intended.
        transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Validation / test preprocessing: the same resize and normalisation, with no
# random augmentation.
preprocessVal_Test = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# Read an image file and turn it into a preprocessed tensor.
def defaultLoader(path, ifTrain):
    """Load the image at *path* and run it through a preprocessing pipeline.

    Args:
        path: Location of the image file; passed to ``Image.open``.
        ifTrain: Truthy selects ``preprocessTrain`` (which includes random
            augmentation); falsy selects ``preprocessVal_Test``.

    Returns:
        The output of the selected pipeline for the RGB-converted image.
    """
    # The context manager closes the underlying file handle deterministically.
    # ``convert`` returns a separate in-memory image, so it remains usable
    # after the file has been closed.
    with Image.open(path) as img_file:
        img_pil = img_file.convert('RGB')
    pipeline = preprocessTrain if ifTrain else preprocessVal_Test
    return pipeline(img_pil)
# Read the class-index -> class-name mapping from a text file.
def getClassFromTxt(txtPath):
    """Parse a classes file into a ``{zero_based_label: name}`` dictionary.

    Every non-blank line is expected to contain at least two
    whitespace-separated fields; the second one has the form
    ``<label>.<name>`` where ``<label>`` is a 1-based integer.

    Args:
        txtPath: Path of the UTF-8 encoded classes file.

    Returns:
        A dict mapping ``int(label) - 1`` to ``name``.

    Raises:
        IndexError: If a non-blank line has fewer than two fields.
        ValueError: If the second field has no ``.`` or its label part is not
            an integer.
    """
    index_classes = {}
    with open(txtPath, 'r', encoding='utf-8') as f:
        for raw_line in f:
            # strip() removes the line terminator only when one is present, so
            # a final line without a trailing newline keeps its last character.
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank lines (e.g. at the end of the file)
            entry = line.split()[1]
            # Split on the first dot only, so a name containing '.' stays whole.
            label, name = entry.split('.', 1)
            index_classes[int(label) - 1] = name
    return index_classes
# Collect the names of the entries inside the image directory.
def getImgPath(imgPath):
    """Return the entry names found in directory *imgPath* as a list.

    Despite the function name, the result holds the bare names produced by
    ``os.listdir`` (in the order it yields them), not joined paths.
    """
    return list(os.listdir(imgPath))
# Check whether the classes are balanced in what a DataLoader yields.
def ifBalance(train_loader, num_classes=200):
    """Print how many samples of each class *train_loader* produces.

    Args:
        train_loader: Iterable yielding ``(data, labels)`` pairs, where
            ``labels`` is an iterable of zero-based integer class labels.
        num_classes: Length of the printed count list; defaults to 200, the
            value that was previously hard-coded.

    Raises:
        IndexError: If a label is not smaller than ``num_classes``.
    """
    class_num_list = [0] * num_classes
    for _, labels in train_loader:
        for label in labels:
            # int() makes the index explicit for both plain ints and
            # single-element integer tensors.
            class_num_list[int(label)] += 1
    print(class_num_list)
# Dataset of bird images whose file names start with a 1-based class number.
class birdTrainDataSet(Dataset):
    """Image dataset that derives every label from the image file name.

    File names are expected to look like ``<number>.<rest>``; the label of an
    image is ``number - 1``.

    Attributes:
        imgTrainPath: Directory the images are read from.
        txtClassPath: Path of the classes file given at construction.
        ifTrain: Forwarded to ``defaultLoader`` to pick the preprocessing.
        index_classes: Mapping returned by ``getClassFromTxt(txtClassPath)``.
        imgNames: Entry names returned by ``getImgPath(imgTrainPath)``.
        class_num_list: Despite the name, after construction this is a float
            array of per-class weights: ``1 / count`` for every class with at
            least one image and ``0.0`` for classes without images.
    """

    def __init__(self, imgTrainPath, txtClassPath, ifTrain, num_classes=200):
        """Index the image directory and compute the per-class weights.

        Args:
            imgTrainPath: Directory containing the image files.
            txtClassPath: Path of the classes text file.
            ifTrain: Whether training preprocessing should be applied.
            num_classes: Number of classes; defaults to 200, the value that
                was previously hard-coded.
        """
        super().__init__()
        self.class_num_list = [0 for _ in range(num_classes)]
        self.txtClassPath = txtClassPath
        self.imgTrainPath = imgTrainPath
        self.ifTrain = ifTrain
        self.index_classes = getClassFromTxt(txtClassPath)
        self.imgNames = getImgPath(imgTrainPath)
        # Replace the zero-filled placeholder with the per-class weights.
        self.getClassNum()

    @staticmethod
    def _label_from_name(imgName):
        """Return the zero-based label encoded before the first '.' of *imgName*."""
        number, _ = imgName.split('.', 1)
        return int(number) - 1

    def __getitem__(self, index):
        """Return ``(image, label)`` for the file at position *index*."""
        imgName = self.imgNames[index]
        label = self._label_from_name(imgName)
        img = defaultLoader(os.path.join(self.imgTrainPath, imgName), self.ifTrain)
        return img, label

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.imgNames)

    def getClassNum(self):
        """Count the images per class and store inverse counts as weights.

        The number of classes is taken from the current length of
        ``class_num_list``. Counting starts from zero on every call, so
        calling this method again recomputes the same result.
        """
        counts = np.zeros(len(self.class_num_list), dtype=np.int64)
        for imgName in self.imgNames:
            counts[self._label_from_name(imgName)] += 1
        # Invert only the non-zero counts: classes without images get weight
        # 0.0 instead of triggering a divide-by-zero warning and becoming inf.
        weights = np.zeros(len(counts), dtype=np.float64)
        np.divide(1.0, counts, out=weights, where=counts > 0)
        self.class_num_list = weights
if __name__ == '__main__':
    imgTrainPath = '../data/bird/train_set'
    txtClassPath = '../data/bird/classes.txt'
    train_dataset = birdTrainDataSet(imgTrainPath, txtClassPath, True)

    # Build one weight per sample (the sequence must be as long as the
    # dataset). The label is read straight from each file name rather than by
    # iterating the dataset, which would decode and augment every image just
    # to obtain its label.
    # NOTE(review): the original author marked this weighting step as
    # "inaccurate" -- check the class counts printed by ifBalance below.
    weights = [
        train_dataset.class_num_list[int(imgName.split('.', 1)[0]) - 1]
        for imgName in train_dataset.imgNames
    ]
    print(','.join([str(s) for s in weights]))

    batch_size = 64
    # Draw len(train_dataset) samples with replacement, each sample being
    # picked with probability proportional to its weight.
    sampler = torch.utils.data.sampler.WeightedRandomSampler(
        weights, len(train_dataset), replacement=True
    )
    trainloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

    # Print the per-class counts actually produced by the weighted sampler.
    ifBalance(trainloader)