-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataloader.py
105 lines (86 loc) · 3.51 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from torchvision import transforms
import torch.utils.data as data
from PIL import Image
import os
import os.path
import numpy as np
from os import listdir
from os.path import join
import random
import math
def is_image_file(filename):
    """Return True if *filename* has a recognised image extension.

    Fix: the original comparison was case-sensitive, so files named e.g.
    ``photo.PNG`` or ``photo.JPG`` were silently excluded from the dataset.
    Lower-casing first accepts any capitalisation; all previously accepted
    names are still accepted (backward compatible).
    """
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))
def default_loader(path):
    """Open the image at *path* and return it as an RGB ``PIL.Image``.

    Conversion to ``'RGB'`` normalises palette/greyscale/alpha images to
    three channels so downstream tensor code sees a uniform layout.
    """
    image = Image.open(path)
    return image.convert('RGB')
def ToTensor(pic):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].

    Integer-mode PIL images ('I', 'I;16') are returned as int tensors
    WITHOUT the /255 scaling (see the ByteTensor check at the bottom).
    """
    if isinstance(pic, np.ndarray):
        # handle numpy array: HWC -> CHW; from_numpy shares memory with `pic`
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility: scale byte values [0, 255] to floats [0.0, 1.0]
        return img.float().div(255)
    # handle PIL Image
    if pic.mode == 'I':
        # 32-bit signed integer pixels
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        # 16-bit pixels — NOTE(review): np.int16 is signed, so values above
        # 32767 would wrap negative; confirm inputs stay in the signed range
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        # all byte-based modes: view the raw image buffer as a flat ByteTensor
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        # for the remaining modes the channel count equals the mode-string length
        nchannel = len(pic.mode)
    # PIL size is (width, height); reshape the flat buffer to H x W x C
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        # byte images get scaled to [0.0, 1.0]
        return img.float().div(255)
    else:
        # integer modes pass through unscaled
        return img
# You should build custom dataset as below.
class CelebA(data.Dataset):
    """Dataset of images read from a flat directory.

    Each item is a 3 x fineSize x fineSize FloatTensor scaled to [-1, 1]
    (the usual GAN input range).

    Args:
        dataPath: directory containing the image files.
        loadSize: side length images are resized to before cropping.
        fineSize: side length of the (deterministic, center) crop.
        flip: when 1, apply a random horizontal flip with probability 0.5.
    """

    def __init__(self, dataPath='data/CelebA/images/', loadSize=64, fineSize=64, flip=1):
        super(CelebA, self).__init__()
        # list all image files in the directory once, up front
        self.image_list = [x for x in listdir(dataPath) if is_image_file(x)]
        self.dataPath = dataPath
        self.loadSize = loadSize
        self.fineSize = fineSize
        self.flip = flip

    def __getitem__(self, index):
        """Load, resize, center-crop, optionally flip, and normalise one image."""
        path = os.path.join(self.dataPath, self.image_list[index])
        img = default_loader(path)
        w, h = img.size
        # BUG FIX: the original tested only `h != self.loadSize`, so a
        # non-square image whose height happened to match was never resized
        # and the crop below produced the wrong region/size. Check both dims.
        if w != self.loadSize or h != self.loadSize:
            img = img.resize((self.loadSize, self.loadSize), Image.BILINEAR)
        if self.loadSize != self.fineSize:
            # deterministic center crop (integer division replaces math.floor)
            offset = (self.loadSize - self.fineSize) // 2
            img = img.crop((offset, offset,
                            offset + self.fineSize, offset + self.fineSize))
        # random horizontal flip for augmentation
        if self.flip == 1 and random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        img = ToTensor(img)            # C x H x W, values in [0, 1]
        img = img.mul_(2).add_(-1)     # rescale in place to [-1, 1]
        return img

    def __len__(self):
        """Number of image files found in dataPath."""
        return len(self.image_list)
def get_loader(root, batch_size, scale_size, num_workers=12, shuffle=True):
    """Build a DataLoader over a CelebA dataset rooted at *root*.

    Args:
        root: image directory passed to :class:`CelebA`.
        batch_size: samples per batch.
        scale_size: used as both loadSize and fineSize (no crop performed).
        num_workers: worker processes for the DataLoader.
        shuffle: whether to shuffle each epoch.

    Returns:
        torch.utils.data.DataLoader yielding 3 x scale_size x scale_size
        tensors in [-1, 1].
    """
    dataset = CelebA(root, scale_size, scale_size, 1)
    # BUG FIX: the original hard-coded shuffle=True, silently ignoring the
    # `shuffle` argument; pass the caller's value through.
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers)
    return data_loader