forked from HIPS/autograd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
67 lines (56 loc) · 2.7 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import absolute_import
import matplotlib.pyplot as plt
import matplotlib.image
import autograd.numpy as np
import autograd.numpy.random as npr
import data_mnist
def load_mnist():
    """Load MNIST via ``data_mnist.mnist()`` and prepare it for training.

    Each image array is reshaped to two dimensions (leading axis kept, all
    remaining axes merged into one) and divided by 255.0.  Each label array
    is converted to a one-hot integer matrix with 10 columns.

    Returns:
        A tuple ``(N_data, train_images, train_labels, test_images,
        test_labels)`` where ``N_data`` is the size of the leading axis of
        the flattened training images.
    """
    def flatten_trailing(batch):
        # Keep axis 0, collapse every other axis into a single one.
        flat_width = np.prod(batch.shape[1:])
        return np.reshape(batch, (batch.shape[0], flat_width))

    def to_one_hot(labels, num_classes):
        # Broadcasting a column of labels against a row of class ids yields
        # a boolean (len(labels), num_classes) mask, stored here as ints.
        class_ids = np.arange(num_classes)[None, :]
        return np.array(labels[:, None] == class_ids, dtype=int)

    raw_train_x, raw_train_y, raw_test_x, raw_test_y = data_mnist.mnist()

    train_x = flatten_trailing(raw_train_x) / 255.0
    test_x = flatten_trailing(raw_test_x) / 255.0
    train_y = to_one_hot(raw_train_y, 10)
    test_y = to_one_hot(raw_test_y, 10)

    return train_x.shape[0], train_x, train_y, test_x, test_y
def plot_images(images, ax, ims_per_row=5, padding=5, digit_dimensions=(28, 28),
                cmap=matplotlib.cm.binary, vmin=None, vmax=None):
    """Draw a batch of flattened images as one padded grid on ``ax``.

    Images should be a (N_images x pixels) matrix.

    Args:
        images: 2-D array; row ``i`` is reshaped to ``digit_dimensions``.
        ax: Axes-like object providing ``matshow``, ``set_xticks`` and
            ``set_yticks``; the grid is drawn on it.
        ims_per_row: Number of images placed side by side in each grid row.
        padding: Width, in pixels, of the border around and between images.
        digit_dimensions: ``(height, width)`` of a single image.
        cmap, vmin, vmax: Passed through unchanged to ``ax.matshow``.

    Returns:
        The object returned by ``ax.matshow``.
    """
    N_images = images.shape[0]
    # Ceiling division: number of grid rows needed to hold every image.
    N_rows = (N_images - 1) // ims_per_row + 1
    # Padding is filled with the smallest pixel value found in the batch.
    pad_value = np.min(images.ravel())
    height, width = digit_dimensions
    concat_images = np.full(((height + padding) * N_rows + padding,
                             (width + padding) * ims_per_row + padding),
                            pad_value)
    for i in range(N_images):
        cur_image = np.reshape(images[i, :], digit_dimensions)
        row_ix, col_ix = divmod(i, ims_per_row)
        row_start = padding + (padding + height) * row_ix
        col_start = padding + (padding + width) * col_ix
        concat_images[row_start: row_start + height,
                      col_start: col_start + width] = cur_image
    cax = ax.matshow(concat_images, cmap=cmap, vmin=vmin, vmax=vmax)
    # Clear the ticks on the axes we were given.  The pyplot helpers
    # (plt.xticks / plt.yticks) act on pyplot's *current* axes, which is not
    # necessarily ``ax`` when the caller manages several axes.
    ax.set_xticks([])
    ax.set_yticks([])
    return cax
def save_images(images, filename, **kwargs):
    """Render ``images`` as a grid with ``plot_images`` and save it.

    Pyplot figure number 1 is fetched, cleared and given a single subplot;
    the grid is drawn there and the result is written to ``filename`` with
    ``plt.savefig``.  Extra keyword arguments are forwarded to
    ``plot_images``.
    """
    figure = plt.figure(1)
    figure.clf()
    axes = figure.add_subplot(1, 1, 1)
    plot_images(images, axes, **kwargs)
    # Hide the background patches of both the figure and the axes.
    for background in (figure.patch, axes.patch):
        background.set_visible(False)
    plt.savefig(filename)
def make_pinwheel(radial_std, tangential_std, num_classes, num_per_class, rate,
                  rs=npr.RandomState(0)):
    """Sample 2-D points arranged in a pinwheel with ``num_classes`` arms.

    Based on code by Ryan P. Adams.

    Args:
        radial_std: Scale applied to the first noise column (which is then
            shifted by 1).
        tangential_std: Scale applied to the second noise column.
        num_classes: Number of arms; their base angles are evenly spaced
            over [0, 2*pi).
        num_per_class: Number of points drawn for each arm.
        rate: Multiplier on ``exp`` of the first feature, added to the
            arm's base angle to give each point its rotation angle.
        rs: Random source providing ``randn(rows, cols)``.  The default is a
            single ``RandomState(0)`` created when the function is defined,
            so successive calls that rely on the default keep drawing from
            the same stream.

    Returns:
        Array of shape ``(num_classes * num_per_class, 2)``; rows are
        grouped by class in increasing class order.
    """
    total_points = num_classes * num_per_class
    scales = np.array([radial_std, tangential_std])

    points = rs.randn(total_points, 2) * scales
    points[:, 0] += 1

    arm_angles = np.linspace(0, 2*np.pi, num_classes, endpoint=False)
    arm_of_point = np.repeat(np.arange(num_classes), num_per_class)
    theta = arm_angles[arm_of_point] + rate * np.exp(points[:, 0])

    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # One row per point, laid out as [cos, -sin, sin, cos], then folded
    # into a (2, 2) matrix per point.
    flat_rotations = np.stack([cos_t, -sin_t, sin_t, cos_t], axis=1)
    per_point_rotation = np.reshape(flat_rotations, (-1, 2, 2))

    # Row-vector times matrix, independently for every point.
    return np.einsum('ni,nij->nj', points, per_point_rotation)