-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
91 lines (78 loc) · 3.37 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Python packages
from termcolor import colored
from tqdm import tqdm
import os
import tarfile
import wget
# PyTorch & Pytorch Lightning
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
# Custom packages
import src.config as cfg
class TinyImageNetDatasetModule(LightningDataModule):
    '''Lightning data module for the Tiny ImageNet (200 classes) dataset.

    Downloads and extracts the dataset archive under ``cfg.DATASET_ROOT_PATH``
    when it is not already present, and builds the train / val / test
    dataloaders from the ``train``, ``val`` and ``test`` sub-directories
    using ``torchvision.datasets.ImageFolder``.
    '''

    __DATASET_NAME__ = 'tiny-imagenet-200'

    def __init__(self, batch_size: int = cfg.BATCH_SIZE):
        '''Store the batch size used by every dataloader of this module.

        Args:
            batch_size: Number of samples per batch (defaults to ``cfg.BATCH_SIZE``).
        '''
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self) -> None:
        '''Download and extract the dataset if its directory does not exist.

        Original note: called only once and on 1 GPU.

        The archive is downloaded into the current working directory,
        extracted into ``cfg.DATASET_ROOT_PATH`` and then deleted, whether or
        not the extraction succeeded.
        '''
        if os.path.exists(os.path.join(cfg.DATASET_ROOT_PATH, self.__DATASET_NAME__)):
            return

        # download data
        print(colored("\nDownloading dataset...", color='green', attrs=('bold',)))
        filename = self.__DATASET_NAME__ + '.tar'
        # wget.download() returns the path it actually wrote to. When a file
        # named `filename` already exists (e.g. left over from an interrupted
        # run) wget saves under a suffixed name such as "name (1).tar", so the
        # returned path must be used instead of the expected one.
        archive_path = wget.download(f'https://hyu-aue8088.s3.ap-northeast-2.amazonaws.com/{filename}')

        try:
            # extract data
            print(colored("\nExtract dataset...", color='green', attrs=('bold',)))
            # Use the 'data' extraction filter where the interpreter supports
            # it: it rejects members that would be written outside the target
            # directory and silences the Python 3.12+ deprecation warning.
            extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
            with tarfile.open(name=archive_path) as tar:
                members = tar.getmembers()
                # Go over each member so tqdm can report progress
                for member in tqdm(iterable=members, total=len(members)):
                    # Extract member
                    tar.extract(member, path=cfg.DATASET_ROOT_PATH, **extract_kwargs)
        finally:
            # Always remove the archive so a failed extraction does not leave a
            # stale file behind for the next run.
            # NOTE(review): a partially extracted dataset directory is not
            # cleaned up here and would make the next call skip the download.
            os.remove(archive_path)

    def _build_loader(self, split: str, transform, *, shuffle: bool = False,
                      pin_memory: bool = False) -> DataLoader:
        '''Create the ImageFolder dataset for ``split`` and wrap it in a DataLoader.

        Args:
            split: Sub-directory of the dataset ('train', 'val' or 'test').
            transform: Transform applied to every loaded image.
            shuffle: Whether the loader shuffles the samples.
            pin_memory: Whether the loader uses pinned memory.

        Returns:
            A DataLoader over the split using ``self.batch_size`` and
            ``cfg.NUM_WORKERS`` workers.
        '''
        split_dir = os.path.join(cfg.DATASET_ROOT_PATH, self.__DATASET_NAME__, split)
        dataset = ImageFolder(split_dir, transform)
        msg = f"[{split.capitalize()}]\t root dir: {dataset.root}\t | # of samples: {len(dataset):,}"
        print(colored(msg, color='blue', attrs=('bold',)))
        return DataLoader(
            dataset,
            shuffle=shuffle,
            pin_memory=pin_memory,
            num_workers=cfg.NUM_WORKERS,
            batch_size=self.batch_size,
        )

    def train_dataloader(self) -> DataLoader:
        '''Return the shuffled training dataloader with augmentation transforms.'''
        tf_train = transforms.Compose([
            transforms.RandomRotation(cfg.IMAGE_ROTATION),
            transforms.RandomHorizontalFlip(cfg.IMAGE_FLIP_PROB),
            transforms.RandomCrop(cfg.IMAGE_NUM_CROPS, padding=cfg.IMAGE_PAD_CROPS),
            transforms.ToTensor(),
            transforms.Normalize(cfg.IMAGE_MEAN, cfg.IMAGE_STD),
        ])
        return self._build_loader('train', tf_train, shuffle=True, pin_memory=True)

    def val_dataloader(self) -> DataLoader:
        '''Return the validation dataloader (tensor conversion + normalization only).'''
        tf_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.IMAGE_MEAN, cfg.IMAGE_STD),
        ])
        return self._build_loader('val', tf_val, pin_memory=True)

    def test_dataloader(self) -> DataLoader:
        '''Return the test dataloader (tensor conversion + normalization only).'''
        tf_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.IMAGE_MEAN, cfg.IMAGE_STD),
        ])
        # pin_memory is left at its default here, as in the original code.
        return self._build_loader('test', tf_test)