-
Notifications
You must be signed in to change notification settings - Fork 0
/
datamodule.py
48 lines (38 loc) · 2.37 KB
/
datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
class data_module(pl.LightningDataModule):
    """Lightning data module serving ImageFolder datasets for train/val/test.

    ``data_dir`` is expected to contain ``train``, ``val`` and ``test``
    subdirectories, each in the layout ``torchvision.datasets.ImageFolder``
    reads (one subfolder per class).

    Args:
        data_dir: Root directory holding the split subdirectories. A trailing
            slash is optional.
        batch_size: Batch size used by every dataloader (the paper uses 1).
        num_workers: Number of worker processes for each DataLoader.
    """

    # ImageNet per-channel mean and std (the statistics torchvision uses).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, data_dir: str = "./stylized_resized/", batch_size=128,
                 num_workers: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # TODO: use dataset-specific normalization statistics.
        # TODO: a training-time augmentation pipeline (random horizontal flip,
        # random 90/180/270 degree rotation, brightness/contrast ColorJitter)
        # was drafted here but is currently disabled; add it to this Compose.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])
        # Deterministic transform for the evaluation splits (val and test).
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        Args:
            stage: Lightning stage name ("fit", "validate", "test", ...), or
                ``None`` to build every split.
        """
        if stage in ("fit", None):
            self.train_dataset = datasets.ImageFolder(
                os.path.join(self.data_dir, "train"), transform=self.transform)
        if stage in ("fit", "validate", None):
            # Validation uses the evaluation transform so that any training
            # augmentation added to self.transform does not skew val metrics.
            self.val_dataset = datasets.ImageFolder(
                os.path.join(self.data_dir, "val"), transform=self.transform_test)
        if stage in ("test", None):
            self.test_dataset = datasets.ImageFolder(
                os.path.join(self.data_dir, "test"), transform=self.transform_test)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        return DataLoader(self.train_dataset, shuffle=True,
                          batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return a DataLoader over the validation split (Lightning hook name)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def validation_dataloader(self):
        """Backward-compatible alias for :meth:`val_dataloader`.

        Lightning does not call this name; it only calls ``val_dataloader``.
        """
        return self.val_dataloader()

    def test_dataloader(self):
        """Return a DataLoader over the test split."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)