# clustering.py
# k-means clustering of randomly generated 2-D points with TensorFlow,
# plotted with seaborn before and after clustering.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
# Experiment parameters.
num_vectors = 1000  # number of 2-D sample points to generate
num_clusters = 3    # number of centroids used by the clustering
num_steps = 100     # number of centroid-update iterations to run

# Draw the samples from a mixture of two 2-D Gaussians. Every point flips a
# fair coin to choose its component; each pair below is (mean, std) for one
# coordinate.
vector_values = []
for _ in range(num_vectors):
    if np.random.random() > 0.5:
        x_params, y_params = (0.5, 0.6), (0.3, 0.9)
    else:
        x_params, y_params = (2.5, 0.4), (0.8, 0.5)
    # x is drawn before y, so a seeded RNG yields the same points as before.
    vector_values.append([np.random.normal(*x_params),
                          np.random.normal(*y_params)])

# Scatter plot of the raw, unclustered data.
df = pd.DataFrame(vector_values, columns=["x", "y"])
# x/y are passed by keyword and the figure size via `height`: recent seaborn
# releases reject positional x/y and the old `size` keyword.
sns.lmplot(x="x", y="y", data=df, fit_reg=False, height=7)
plt.show()
# Build the graph for one k-means iteration (assign points, then re-average).

# All sample points as one constant tensor: (num_vectors, 2).
vectors = tf.constant(vector_values)

# Start the centroids at num_clusters rows taken from a shuffled copy of
# the data.
shuffled_vectors = tf.random_shuffle(vectors)
initial_centroids = tf.slice(shuffled_vectors, [0, 0], [num_clusters, -1])
centroids = tf.Variable(initial_centroids)

# Insert singleton axes so the subtraction below broadcasts every point
# against every centroid:
#   (1, num_vectors, 2) - (num_clusters, 1, 2) -> (num_clusters, num_vectors, 2)
expanded_vectors = tf.expand_dims(vectors, 0)
expanded_centroids = tf.expand_dims(centroids, 1)
print(expanded_vectors.get_shape())
print(expanded_centroids.get_shape())

# Squared Euclidean distance from each centroid to each point, then the
# index of the nearest centroid for every point.
differences = tf.subtract(expanded_vectors, expanded_centroids)
distances = tf.reduce_sum(tf.square(differences), 2)
assignments = tf.argmin(distances, 0)


def _cluster_mean(cluster_index):
    """Return the mean of the points assigned to `cluster_index`.

    The result keeps a leading axis of size 1 so the per-cluster means can
    be concatenated along axis 0 into a single centroid matrix.
    """
    is_member = tf.equal(assignments, cluster_index)
    # tf.where yields the member indices as a column; flatten it to one row.
    member_rows = tf.reshape(tf.where(is_member), [1, -1])
    members = tf.gather(vectors, member_rows)
    # NOTE(review): if no point is assigned to this cluster the gather is
    # empty and this mean is presumably NaN; confirm whether that can occur.
    return tf.reduce_mean(members, axis=[1])


means = tf.concat(axis=0,
                  values=[_cluster_mean(c) for c in range(num_clusters)])

# Op that overwrites the centroids with the freshly computed means.
update_centroids = tf.assign(centroids, means)
init_op = tf.global_variables_initializer()
# Run the clustering iterations. The session is a context manager so its
# resources are released as soon as training finishes.
with tf.Session() as sess:
    sess.run(init_op)
    for step in range(num_steps):
        # Fetch the assign op's own output for the centroid values: reading
        # `centroids` in the same run() call has no guaranteed ordering
        # relative to the assignment and may return the pre-update value.
        centroid_values, assignment_values = sess.run(
            [update_centroids, assignments])

print("centroids")
print(centroid_values)
# Pair every sample point with the cluster index from the last iteration
# and plot the points coloured by cluster.
df = pd.DataFrame({
    "x": [v[0] for v in vector_values],
    "y": [v[1] for v in vector_values],
    "cluster": assignment_values,
})
# x/y are passed by keyword and the figure size via `height`: recent seaborn
# releases reject positional x/y and the old `size` keyword.
sns.lmplot(x="x", y="y", data=df,
           fit_reg=False, height=7,
           hue="cluster", legend=False)
plt.show()