-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_util.py
39 lines (34 loc) · 1.34 KB
/
mnist_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import struct
import numpy as np
def mnist_read_images(filename):
    """Read an MNIST (IDX3) image file into a 3-D uint8 array.

    The file is expected to start with four big-endian 32-bit header
    fields (magic number, image count, rows, columns) followed by one
    byte per pixel.

    Args:
        filename: Path of the uncompressed image file.

    Returns:
        A ``uint8`` NumPy array of shape ``(count, rows, cols)`` built
        with ``np.frombuffer`` over the bytes read from the file.

    Raises:
        ValueError: If the magic number is not 2051, or if the file holds
            fewer pixel bytes than the header announces.
        struct.error: If the file is shorter than the 16-byte header.
    """
    header_format = '>IIII'
    with open(filename, 'rb') as f:
        data = f.read()

    magic, count, rows, cols = struct.unpack_from(header_format, data)
    if magic != 2051:
        raise ValueError('failed to read mnist images from {}'.format(filename))

    header_size = struct.calcsize(header_format)
    pixel_count = count * rows * cols
    available = len(data) - header_size
    # Check the payload up front so a truncated file is reported clearly
    # instead of surfacing as an opaque reshape failure.
    if available < pixel_count:
        raise ValueError(
            'mnist image file {} is truncated: expected {} pixel bytes, '
            'found {}'.format(filename, pixel_count, available)
        )

    # Read exactly the amount announced by the header; any trailing bytes
    # after the declared images are ignored.
    images = np.frombuffer(
        data, dtype=np.uint8, count=pixel_count, offset=header_size
    )
    return images.reshape(count, rows, cols)
def mnist_one_hot_encode(labels):
    """Convert a sequence of integer class labels to a one-hot matrix.

    The number of columns is ``max(labels) + 1``; row ``i`` contains a
    single ``1`` in column ``labels[i]`` and zeros elsewhere.

    Args:
        labels: One-dimensional sequence or NumPy array of integer labels.

    Returns:
        An integer NumPy array of shape ``(len(labels), max(labels) + 1)``.
        Empty input yields an array of shape ``(0, 0)``.

    Raises:
        TypeError: If the labels are not of an integer-compatible dtype
            (for example floats).
    """
    indices = np.asarray(labels)
    if indices.size == 0:
        # No labels means no known classes: return an empty matrix rather
        # than failing on the maximum of an empty sequence.
        return np.zeros((0, 0), dtype=np.int_)

    # Widen to a platform integer so that ``max + 1`` cannot wrap around
    # for narrow dtypes such as uint8; 'same_kind' rejects float labels.
    indices = indices.astype(np.intp, casting='same_kind', copy=False)

    num_labels = indices.shape[0]
    num_classes = int(indices.max()) + 1
    encoded = np.zeros((num_labels, num_classes), dtype=np.int_)
    # One vectorized assignment sets the single hot entry of every row.
    encoded[np.arange(num_labels), indices] = 1
    return encoded
def mnist_read_labels(filename, one_hot_encoding=False):
    """Read an MNIST (IDX1) label file.

    The file is expected to start with two big-endian 32-bit header
    fields (magic number, label count) followed by one byte per label.

    Args:
        filename: Path of the uncompressed label file.
        one_hot_encoding: When true, return the labels converted by
            ``mnist_one_hot_encode`` instead of the raw label values.

    Returns:
        A one-dimensional ``uint8`` NumPy array with ``count`` labels, or
        the result of ``mnist_one_hot_encode`` on that array when
        ``one_hot_encoding`` is true.

    Raises:
        ValueError: If the magic number is not 2049, or if the file holds
            fewer label bytes than the header announces.
        struct.error: If the file is shorter than the 8-byte header.
    """
    header_format = '>II'
    with open(filename, 'rb') as f:
        data = f.read()

    magic, count = struct.unpack_from(header_format, data)
    if magic != 2049:
        raise ValueError('failed to read mnist labels from {}'.format(filename))

    header_size = struct.calcsize(header_format)
    available = len(data) - header_size
    # A short payload would otherwise silently yield too few labels.
    if available < count:
        raise ValueError(
            'mnist label file {} is truncated: expected {} labels, '
            'found {}'.format(filename, count, available)
        )

    # Read exactly the number of labels announced by the header; any
    # trailing bytes are ignored rather than returned as labels.
    labels = np.frombuffer(
        data, dtype=np.uint8, count=count, offset=header_size
    )
    if one_hot_encoding:
        return mnist_one_hot_encode(labels)
    return labels
def mnist_read(images_filename, labels_filename, one_hot_encoding=False):
    """Load an MNIST image file together with its label file.

    Args:
        images_filename: Path passed to ``mnist_read_images``.
        labels_filename: Path passed to ``mnist_read_labels``.
        one_hot_encoding: Forwarded to ``mnist_read_labels``.

    Returns:
        A tuple ``(images, labels)`` holding the results of the two
        readers, images first.
    """
    # Tuple elements are evaluated left to right, so the image file is
    # read before the label file.
    return (
        mnist_read_images(images_filename),
        mnist_read_labels(labels_filename, one_hot_encoding),
    )