-
Notifications
You must be signed in to change notification settings - Fork 25
/
dataloader_cifar.py
29 lines (27 loc) · 1.03 KB
/
dataloader_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.utils.data.distributed import DistributedSampler
def load_data(
    batchsize: int,
    numworkers: int,
    *,
    root: str = './',
    download: bool = False,
) -> tuple[DataLoader, DistributedSampler]:
    """Build a distributed DataLoader over the CIFAR-10 training split.

    Each image goes through ``ToTensor`` and then ``Normalize`` with
    mean 0.5 and std 0.5 per channel. ``ToTensor`` scales pixel values to
    [0, 1], and the normalization maps that range to [-1, 1].
    ``transback`` inverts the normalization.

    Args:
        batchsize: Batch size of the returned loader. With a
            DistributedSampler this is the per-process batch size.
        numworkers: Number of DataLoader worker processes.
        root: Directory that holds (or will receive) the CIFAR-10 files.
            Defaults to the current directory, as before.
        download: If True, torchvision downloads the dataset into ``root``
            when it is missing. Defaults to False, as before.

    Returns:
        ``(trainloader, sampler)``. The sampler is returned so the
        training loop can call ``sampler.set_epoch(epoch)`` at the start
        of each epoch. Without that call, PyTorch's DistributedSampler
        repeats the same shuffle order every epoch.

    Note:
        ``DistributedSampler`` is built without explicit
        ``num_replicas``/``rank``, so it reads them from the default
        ``torch.distributed`` process group. That group must be
        initialized before this function is called.
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    data_train = CIFAR10(
        root=root,
        train=True,
        download=download,
        transform=trans,
    )
    # Shards the dataset across ranks. Shuffling is handled here, so the
    # DataLoader itself must not be given shuffle=True.
    sampler = DistributedSampler(data_train)
    trainloader = DataLoader(
        data_train,
        batch_size=batchsize,
        num_workers=numworkers,
        sampler=sampler,
        # Drop the trailing incomplete batch so every step has the same size.
        drop_last=True,
    )
    return trainloader, sampler
def transback(data: Tensor) -> Tensor:
    """Undo the (0.5, 0.5) per-channel normalization applied by the loader.

    Maps x -> x * 0.5 + 0.5, so the range [-1, 1] goes back to [0, 1].
    """
    half = 0.5
    scaled = data * half
    return scaled + half