-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
29 lines (22 loc) · 832 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""Dataset interface for pre-TF r1.4."""
import numpy as np
class Dataset:
    """In-memory train/validation split over paired NumPy arrays.

    ``x`` is split and shuffled along axis 1 and ``y`` along axis 0, i.e. the
    sample axis of ``x`` is its second dimension (presumably a time-major
    ``(time, batch, ...)`` layout -- confirm with callers). The leading
    ``validation_split`` fraction of samples becomes the validation set and
    the remainder becomes the training set.

    Attributes:
        x: dict with keys ``'train'`` and ``'val'`` holding slices of ``x``.
        y: dict with keys ``'train'`` and ``'val'`` holding slices of ``y``.
    """

    def __init__(self, x, y, validation_split=1.):
        """Initialize with the data and hyperparameters.

        Args:
            x: array with at least two dimensions; axis 1 indexes samples.
            y: array with at least one dimension; axis 0 indexes samples.
            validation_split: fraction in ``[0, 1]`` of the samples, taken
                from the front, that form the validation set. The default
                ``1.`` puts every sample in the validation set and leaves the
                training set empty.

        Raises:
            ValueError: if ``validation_split`` is outside ``[0, 1]`` or if
                ``x`` and ``y`` disagree on the number of samples.
        """
        # Written as `not (lo <= v <= hi)` so NaN is rejected as well.
        if not 0. <= validation_split <= 1.:
            raise ValueError(
                'validation_split must be in [0, 1], got {!r}'.format(
                    validation_split))
        length = y.shape[0]
        if x.shape[1] != length:
            raise ValueError(
                'x has {} samples along axis 1 but y has {} along axis 0'
                .format(x.shape[1], length))
        # Split the data: the first `split_ind` samples go to validation.
        split_ind = int(length * validation_split)
        # Slice only the sample axis so x of any rank >= 2 and y of any
        # rank >= 1 are accepted (identical to the old 3-D/2-D indexing).
        self.x = {'train': x[:, split_ind:], 'val': x[:, :split_ind]}
        self.y = {'train': y[split_ind:], 'val': y[:split_ind]}

    def shuffle(self, rng=None):
        """Shuffle the training data, keeping x/y sample pairs aligned.

        The ``'train'`` entries of ``self.x`` and ``self.y`` are replaced by
        permuted copies; the validation data is left untouched.

        Args:
            rng: optional random source exposing ``permutation(n)`` (e.g.
                ``np.random.RandomState`` or ``np.random.default_rng()``) for
                reproducible shuffles. When ``None`` the global
                ``np.random`` state is used.
        """
        self.x['train'], self.y['train'] = self._shuffle(
            self.x['train'],
            self.y['train'],
            rng=rng,
        )

    @staticmethod
    def _shuffle(x, y, rng=None):
        """Shuffle the given data and return them.

        ``x`` is permuted along axis 1 and ``y`` along axis 0 with the same
        index order. Fancy indexing returns new arrays, so the inputs are not
        modified.
        """
        if rng is None:
            ind = np.arange(y.shape[0])
            np.random.shuffle(ind)
        else:
            ind = rng.permutation(y.shape[0])
        return x[:, ind], y[ind]