-
Notifications
You must be signed in to change notification settings - Fork 33
/
data_loader.py
31 lines (22 loc) · 910 Bytes
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch.utils.data import DataLoader
import h5py
class VideoData(torch.utils.data.Dataset):
    """Map-style dataset reading per-video groups from an HDF5 file.

    Sample ``i`` (0-based) is read from the top-level group named
    ``'video_<i + 1>'``; each group must contain a ``'feature'`` and a
    ``'label'`` dataset.

    Args:
        data_path: Path to the HDF5 file. It is opened read-only and kept
            open for the lifetime of the object; call :meth:`close` to
            release the handle.
    """

    def __init__(self, data_path):
        # Open read-only explicitly: before h5py 3.0 the default mode was
        # append ('a'), which can create or modify the file on disk.
        self.data_file = h5py.File(data_path, 'r')

    def __len__(self):
        # NOTE(review): counts every top-level entry in the file, so this
        # assumes the file holds only 'video_<k>' groups numbered 1..len.
        return len(self.data_file)

    def __getitem__(self, index):
        """Return ``(feature, label, key_index)`` for the 0-based ``index``.

        ``feature`` is the stored feature array converted to a tensor and
        transposed with ``.t()`` (presumably (frames, channels) ->
        (channels, frames); TODO confirm against the file layout).
        ``label`` is a ``torch.long`` tensor. ``key_index`` is the 1-based
        number used in the group name ``'video_<key_index>'``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        size = len(self)
        # Raise IndexError rather than letting the group lookup fail with a
        # KeyError, so out-of-range access follows the sequence protocol
        # (e.g. plain iteration over the dataset stops cleanly).
        if not 0 <= index < size:
            raise IndexError(
                f'index {index} out of range for dataset of size {size}')
        index += 1  # group names are 1-based
        video = self.data_file['video_' + str(index)]
        feature = torch.tensor(video['feature'][()]).t()
        label = torch.tensor(video['label'][()], dtype=torch.long)
        return feature, label, index

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.data_file.close()
def get_loader(path, batch_size=5, *, shuffle=False, generator=None):
    """Build the training loader and the test split for an HDF5 video file.

    The dataset is randomly split so that the test part holds
    ``len(dataset) // 5`` samples (about 20%) and the training part holds
    the rest.

    Args:
        path: HDF5 file path passed to ``VideoData``.
        batch_size: Batch size of the returned training ``DataLoader``.
        shuffle: Reshuffle the training split every epoch. Defaults to
            ``False``, which keeps the order fixed after the split.
        generator: Optional ``torch.Generator`` controlling the random split;
            pass a seeded generator for a reproducible train/test partition.
            When ``None``, torch's default global generator is used.

    Returns:
        ``(train_loader, test_dataset)``. Note that the test part is returned
        as a dataset (the ``Subset`` produced by ``random_split``), not
        wrapped in a ``DataLoader``.
    """
    dataset = VideoData(path)
    total = len(dataset)
    test_size = total // 5
    # Only forward a generator when one was given, so the default call is
    # exactly the original unseeded split.
    split_kwargs = {} if generator is None else {'generator': generator}
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [total - test_size, test_size], **split_kwargs)
    # NOTE(review): the default collate function stacks samples, so every
    # feature/label in a batch must share a shape when batch_size > 1.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle)
    return train_loader, test_dataset
if __name__ == '__main__':
    # Smoke run: build the training loader and the held-out test split from
    # the default dataset file in the working directory.
    train_loader, test_dataset = get_loader('fcsn_dataset.h5')