-
Notifications
You must be signed in to change notification settings - Fork 44
/
SS_dataset.py
123 lines (101 loc) · 3.54 KB
/
SS_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Docstrings in this module use reStructuredText markup.
__docformat__ = 'restructuredtext en'
__authors__ = "Alessandro Sordoni"
# NOTE(review): "[email protected]" looks like scraper-obfuscated residue
# rather than a real address; restore the contact from the upstream source.
__contact__ = "Alessandro Sordoni <[email protected]>"
import numpy as np
import os, gc
import cPickle
import copy
import logging
import threading
import Queue
import collections
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class SSFetcher(threading.Thread):
    """Background thread that fills ``parent.queue`` with batches of sessions.

    The thread reads these attributes of ``parent``: ``data``, ``data_len``,
    ``rank_data``, ``has_ranks``, ``batch_size``, ``max_len``,
    ``use_infinite_loop``, ``rng``, ``exit_flag`` and ``queue``.
    """

    def __init__(self, parent):
        """:param parent: the owning iterator object (see class docstring)."""
        threading.Thread.__init__(self)
        self.parent = parent
        # Visiting order over the sessions; reshuffled at the start of each pass.
        self.indexes = np.arange(parent.data_len)

    def run(self):
        """Queue batches until the data runs out or ``exit_flag`` is set.

        Each batch holds at most ``batch_size`` sessions. Sessions longer than
        ``max_len`` are skipped. When ``has_ranks`` is true a batch is a
        ``(sessions, ranks)`` tuple, otherwise it is the list of sessions.

        ``None`` is queued as an end marker in two cases:

        * when ``use_infinite_loop`` is false, after the data is exhausted;
        * when a complete pass over the data yields no usable session (for
          example, empty data, or every session exceeds ``max_len``).
        """
        diter = self.parent
        diter.rng.shuffle(self.indexes)
        offset = 0
        # Sessions accepted since the last (re)shuffle. A whole pass that
        # accepts none means no session can ever fit, so looping again would
        # spin forever (or, on empty data, index past the end).
        accepted_in_pass = 0
        while not diter.exit_flag:
            last_batch = False
            sessions = []
            ranks = []
            while len(sessions) < diter.batch_size:
                if offset == diter.data_len:
                    if not diter.use_infinite_loop or accepted_in_pass == 0:
                        last_batch = True
                        break
                    # Start a new pass over the data in a fresh random order.
                    diter.rng.shuffle(self.indexes)
                    offset = 0
                    accepted_in_pass = 0
                index = self.indexes[offset]
                offset += 1
                session = diter.data[index]
                if len(session) > diter.max_len:
                    continue
                accepted_in_pass += 1
                sessions.append(session)
                if diter.has_ranks:
                    # Ranks stay aligned with sessions by sharing the same index.
                    ranks.append(diter.rank_data[index])
            if sessions:
                if diter.has_ranks:
                    diter.queue.put((sessions, ranks))
                else:
                    diter.queue.put(sessions)
            if last_batch:
                diter.queue.put(None)
                return
class SSIterator(object):
    """Iterate over pickled sessions in batches prepared by a background thread.

    The session pickle (and, when ``rank_file`` is given, a rank pickle of the
    same length) is loaded by the constructor. :meth:`start` launches an
    ``SSFetcher`` thread that fills ``self.queue``; :meth:`next` pops one batch
    per call. A batch is a list of sessions, or a ``(sessions, ranks)`` tuple
    when ranks are loaded. Once the fetcher queues its ``None`` end marker,
    :meth:`next` returns ``None`` and keeps returning ``None`` on later calls
    (it does not raise ``StopIteration``).
    """

    # _stop_fetcher() waits at most _JOIN_ROUNDS * _JOIN_POLL_SECONDS seconds
    # for the fetcher thread before giving up on it (the thread is a daemon).
    _JOIN_ROUNDS = 100
    _JOIN_POLL_SECONDS = 0.05

    def __init__(self,
                 rng,
                 batch_size,
                 session_file=None,
                 rank_file=None,
                 dtype="int32",
                 can_fit=False,
                 queue_size=100,
                 cache_size=100,
                 shuffle=True,
                 use_infinite_loop=True,
                 max_len=1000):
        """Store the configuration and load the data files.

        Every argument is kept as an attribute of the same name.

        :param rng: object whose ``shuffle`` method orders the sessions
            (presumably a ``numpy.random.RandomState`` -- confirm with callers).
        :param batch_size: maximum number of sessions per batch.
        :param session_file: path of the pickled sequence of sessions.
        :param rank_file: optional path of a pickled sequence aligned with the
            sessions; when given, batches are ``(sessions, ranks)`` tuples.
        :param queue_size: capacity of the fetcher-to-consumer queue.
        :param use_infinite_loop: reshuffle and restart at the end of the data
            instead of signalling the end.
        :param max_len: sessions longer than this are skipped.
        :raises ValueError: if the rank data length differs from the session
            data length.

        ``dtype``, ``can_fit``, ``cache_size`` and ``shuffle`` are stored but are
        not read by this class.
        """
        self.rng = rng
        self.batch_size = batch_size
        self.session_file = session_file
        self.rank_file = rank_file
        self.dtype = dtype
        self.can_fit = can_fit
        self.queue_size = queue_size
        self.cache_size = cache_size
        self.shuffle = shuffle
        self.use_infinite_loop = use_infinite_loop
        self.max_len = max_len
        self.has_ranks = rank_file is not None
        self.load_files()
        self.exit_flag = False

    def load_files(self):
        """Load ``session_file`` and, if configured, ``rank_file``.

        :raises ValueError: if the two files hold a different number of entries.
        """
        # NOTE(review): unpickling can execute arbitrary code; only load
        # trusted data files. Binary mode is required for pickle data.
        with open(self.session_file, 'rb') as session_fh:
            self.data = cPickle.load(session_fh)
        self.data_len = len(self.data)
        logger.debug('Data len is %d', self.data_len)
        if self.has_ranks:
            with open(self.rank_file, 'rb') as rank_fh:
                self.rank_data = cPickle.load(rank_fh)
            self.rank_data_len = len(self.rank_data)
            if self.rank_data_len != self.data_len:
                raise ValueError(
                    'Rank data has %d entries but session data has %d'
                    % (self.rank_data_len, self.data_len))
            logger.debug('Rank data len is %d', self.rank_data_len)

    def _drain_queue(self):
        """Discard every batch currently waiting in ``self.queue``."""
        try:
            while True:
                self.queue.get_nowait()
        except Queue.Empty:
            pass

    def _stop_fetcher(self):
        """Ask the fetcher thread to exit and wait a bounded time for it."""
        gather = getattr(self, 'gather', None)
        # Nothing to stop, or we are running on the fetcher thread itself
        # (joining the current thread raises RuntimeError).
        if gather is None or gather is threading.current_thread():
            return
        self.exit_flag = True
        for _ in range(self._JOIN_ROUNDS):
            if not gather.is_alive():
                break
            # The fetcher may be blocked in queue.put() on a full queue;
            # making room lets it return to its loop and see exit_flag.
            self._drain_queue()
            gather.join(self._JOIN_POLL_SECONDS)

    def start(self):
        """(Re)start batch production with a fresh queue and fetcher thread."""
        # A fetcher left over from an earlier start() reads self.queue and
        # self.exit_flag on every put/loop, so it would otherwise keep feeding
        # batches (and its end marker) into the new queue.
        self._stop_fetcher()
        self.exit_flag = False
        self.queue = Queue.Queue(maxsize=self.queue_size)
        self.gather = SSFetcher(self)
        self.gather.daemon = True
        self.gather.start()

    def __del__(self):
        """Best-effort shutdown of the fetcher thread."""
        self._stop_fetcher()

    def __iter__(self):
        return self

    def next(self):
        """Return the next batch, or ``None`` once the data is exhausted.

        Blocks until the fetcher has queued something. Requires a prior call
        to :meth:`start`.
        """
        if self.exit_flag:
            return None
        batch = self.queue.get()
        # The fetcher queues None as its end-of-data marker.
        if batch is None:
            self.exit_flag = True
        return batch