forked from facebookresearch/mae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
flylight_dataset.py
51 lines (40 loc) · 1.37 KB
/
flylight_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import logging
from glob import glob

import numpy as np
import torch
import zarr
from torch.utils.data import Dataset
class FlylightDataset(Dataset):
    """Dataset yielding random cubic crops from zarr volumes.

    Every ``*.zarr`` container matching ``<root_dir>/*/*/*.zarr`` is one
    sample. Indexing opens the ``volumes/raw`` array of that container, cuts
    a random cube with edge length ``input_size`` out of its last three axes
    (any leading axes are kept whole), clips the values to
    ``[0, MAX_INTENSITY]`` and rescales them to ``[0, 1]``.

    Args:
        root_dir: Directory searched for sample containers (``str`` or
            path-like).
        input_size: Edge length, in voxels, of the cubic crop.
        transform: Stored as ``self.transform``. NOTE: it is currently *not*
            applied by ``__getitem__``; the tensor is returned untransformed.
    """

    # Upper clipping bound; also the divisor used to rescale to [0, 1].
    MAX_INTENSITY = 1500

    def __init__(self, root_dir, input_size, transform=None):
        self.root_dir = root_dir
        self.input_size = input_size
        self.transform = transform
        # Sorted because glob order is filesystem-dependent; a stable order
        # keeps the index -> sample mapping identical in every process.
        self.samples = sorted(glob(f"{self.root_dir}/*/*/*.zarr"))
        self.num_samples = len(self.samples)

    def __len__(self) -> int:
        """Return the number of sample containers found at construction."""
        return self.num_samples

    def _open_first_readable(self, idx):
        """Return ``(sample_path, array)`` for the first openable sample.

        Starts at ``idx`` and walks forward (wrapping around) past samples
        that fail to open, trying each sample at most once.

        Raises:
            IndexError: If the dataset is empty or ``idx`` is out of range.
            RuntimeError: If none of the samples could be opened.
        """
        if self.num_samples == 0:
            raise IndexError(f"no samples found under {self.root_dir!r}")

        last_error = None
        # Bounded retry: one pass over the dataset instead of looping forever.
        for _ in range(self.num_samples):
            sample = self.samples[idx]
            try:
                return sample, zarr.open(sample, mode="r", path="volumes/raw")
            except Exception as err:  # best-effort skip of unreadable samples
                last_error = err
                logging.getLogger(__name__).warning(
                    "could not open %s (%r); trying the next sample",
                    sample,
                    err,
                )
                idx = (idx + 1) % self.num_samples

        raise RuntimeError(
            f"none of the {self.num_samples} samples under "
            f"{self.root_dir!r} could be opened"
        ) from last_error

    def __getitem__(self, idx):
        """Return a random, normalized crop of sample ``idx``.

        If sample ``idx`` cannot be opened, the next openable sample is used
        instead.

        Returns:
            A ``torch.float32`` tensor with values in ``[0, 1]`` whose last
            three dimensions each have length ``input_size``.

        Raises:
            IndexError: If the dataset is empty or ``idx`` is out of range.
            RuntimeError: If no sample in the dataset could be opened.
            ValueError: If the volume is smaller than ``input_size`` along
                any of its last three axes.
        """
        sample, zraw = self._open_first_readable(idx)

        size = self.input_size
        depth, height, width = zraw.shape[-3:]
        if min(depth, height, width) < size:
            raise ValueError(
                f"volume {sample!r} with shape {tuple(zraw.shape)} is smaller "
                f"than the requested crop size {size}"
            )

        # randint's upper bound is exclusive, hence the + 1.
        d0 = np.random.randint(0, depth - size + 1)
        h0 = np.random.randint(0, height - size + 1)
        w0 = np.random.randint(0, width - size + 1)

        # Only the cropped region is read; the ellipsis keeps leading axes.
        raw = np.asarray(
            zraw[..., d0:d0 + size, h0:h0 + size, w0:w0 + size]
        )

        # Normalize: clip to [0, MAX_INTENSITY], then scale to [0, 1].
        raw = np.clip(raw, 0, self.MAX_INTENSITY)
        raw = (raw / float(self.MAX_INTENSITY)).astype(np.float32)

        # NOTE: self.transform is intentionally not applied here.
        return torch.from_numpy(raw)