-
-
Notifications
You must be signed in to change notification settings - Fork 49
/
data.py
55 lines (50 loc) · 1.89 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from PIL import Image
import os
class ImageData(Dataset):
def __init__(self, root_path="CACD2000/", label_path="data/label.npy", name_path="data/name.npy", train_mode = "train"):
"""
Initialize some variables
Load labels & names
define transform
"""
self.root_path = root_path
self.image_labels = np.load(label_path)
self.image_names = np.load(name_path)
self.train_mode = train_mode
self.transform = {
'train': transforms.Compose([
transforms.Resize(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# transforms.Normalize([0.656,0.487,0.411], [1., 1., 1.])
]),
'val': transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
# transforms.Normalize([0.656,0.487,0.411], [1., 1., 1.])
]),
}
def __len__(self):
"""
Get the length of the entire dataset
"""
print("Length of dataset is ", self.image_labels.shape[0])
return self.image_labels.shape[0]
def __getitem__(self, idx):
"""
Get the image item by index
"""
image_name = os.path.join(self.root_path, self.image_names[idx])
image = Image.open(image_name)
image_label = self.image_labels[idx].astype(int) - 1
transformed_img = self.transform[self.train_mode](image)
sample = {'image':transformed_img, 'label':torch.from_numpy(image_label)}
return sample