-
Notifications
You must be signed in to change notification settings - Fork 1
/
ms_marco_data_reader.py
69 lines (52 loc) · 2.71 KB
/
ms_marco_data_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pickle
import codecs
import torch
from torch.utils.data import Dataset, IterableDataset
class MSMarcoDataset(Dataset):
    """Map-style dataset over a cached collection of MS MARCO features.

    The whole collection is loaded eagerly with ``torch.load`` and kept in
    memory. Each feature is expected to expose ``input_ids``, ``input_mask``
    and ``segment_ids``; in the ``'train'`` subset it must also expose
    ``is_selected``, ``start_position``, ``end_position`` and
    ``is_impossible``.
    """

    def __init__(self, cached_features_file, subset):
        """Load the features stored in *cached_features_file*.

        Args:
            cached_features_file: Path (or file-like object) handed to
                ``torch.load``.
            subset: Name of the data split; label tensors are produced only
                when it equals ``'train'``.
        """
        self._subset = subset
        self._is_training = (subset == 'train')
        loaded = torch.load(cached_features_file)
        self._features = loaded
        self._size = len(loaded)

    def get_features(self):
        """Return the underlying feature collection (not a copy)."""
        return self._features

    def __len__(self):
        """Return the number of features counted at construction time."""
        return self._size

    def __getitem__(self, index):
        """Return the tensors of the feature stored at *index*.

        Returns:
            ``(input_ids, input_mask, segment_ids, index)`` outside training.
            In training the tuple is extended with
            ``(is_selected, start_position, end_position, is_impossible)``.
            ``is_selected`` is a float tensor, every other entry is a long
            tensor. Falsy start/end positions (e.g. ``None``) become 0.
        """

        def as_long(value):
            return torch.tensor(value, dtype=torch.long)

        record = self._features[index]
        # Convert the index right after the lookup, before touching any
        # attribute of the record.
        position = as_long(index)
        model_inputs = (
            as_long(record.input_ids),
            as_long(record.input_mask),
            as_long(record.segment_ids),
            position,
        )
        if not self._is_training:
            return model_inputs

        selected = torch.tensor(record.is_selected, dtype=torch.float)
        span_start = as_long(record.start_position or 0)
        span_end = as_long(record.end_position or 0)
        impossible = as_long(record.is_impossible)
        return model_inputs + (selected, span_start, span_end, impossible)
class MSMarcoIterDataset(IterableDataset):
    """Iterable dataset that streams MS MARCO features from a text file.

    Every non-blank line of the file holds one feature object that was
    pickled and then base64-encoded. Lines are decoded lazily, so the file
    is never loaded into memory as a whole.

    NOTE(review): no per-worker sharding is done here; with a multi-worker
    ``DataLoader`` each worker would presumably iterate the full file —
    confirm how this dataset is consumed.
    """

    def __init__(self, cached_features_file, subset):
        """Remember the file to stream from.

        Args:
            cached_features_file: Path of the text file with one encoded
                feature per line.
            subset: Name of the data split; label tensors are produced only
                when it equals ``'train'``.
        """
        self._subset = subset
        self._file = cached_features_file
        self._is_training = (subset == 'train')

    def __iter__(self):
        """Yield the tensor tuple of every feature in file order.

        The file is opened per iteration and closed as soon as the
        generator is exhausted or discarded. Blank lines are skipped and do
        not consume an index.
        """
        with open(self._file) as handle:
            encoded_lines = (line for line in handle if line.strip())
            for index, line in enumerate(encoded_lines):
                yield self.line_mapper(line, index)

    def line_mapper(self, line, index=0):
        """Decode one encoded line into the tensors of its feature.

        Args:
            line: Base64 text of a pickled feature; surrounding whitespace
                is ignored.
            index: Zero-based position of the feature in the stream.
                NOTE(review): the previous code read an ``index`` that was
                never defined; the line position is assumed to be the
                intended value — confirm against the consumer.

        Returns:
            ``(input_ids, input_mask, segment_ids, index)`` outside training.
            In training the tuple is extended with
            ``(is_selected, start_position, end_position, is_impossible)``.
            ``is_selected`` is a float tensor, every other entry is a long
            tensor. Falsy start/end positions (e.g. ``None``) become 0, the
            same fallback ``MSMarcoDataset`` applies.
        """
        payload = line.strip()
        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # data; only use feature files that come from a trusted source.
        feature = pickle.loads(codecs.decode(payload.encode(), "base64"))

        index = torch.tensor(index, dtype=torch.long)
        input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
        input_mask = torch.tensor(feature.input_mask, dtype=torch.long)
        segment_ids = torch.tensor(feature.segment_ids, dtype=torch.long)
        outputs = (input_ids, input_mask, segment_ids, index)

        if self._is_training:
            is_selected = torch.tensor(feature.is_selected, dtype=torch.float)
            start_position = torch.tensor(feature.start_position or 0, dtype=torch.long)
            end_position = torch.tensor(feature.end_position or 0, dtype=torch.long)
            is_impossible = torch.tensor(feature.is_impossible, dtype=torch.long)
            outputs = outputs + (is_selected, start_position, end_position, is_impossible)
        return outputs