-
Notifications
You must be signed in to change notification settings - Fork 230
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
598 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
10 10 | ||
0 2500 3300 1 7400 8200 2 900 900 3 3200 4000 4 4900 4900 5 1100 1100 6 5900 6500 7 5600 5600 8 3900 4900 9 2000 2200 | ||
0 3900 4700 2 8900 9100 4 6700 8300 9 1100 1100 3 6300 7500 1 2500 3100 6 4600 4600 5 4300 4900 7 7100 7300 8 2700 3300 | ||
1 8600 9600 0 7700 9300 3 3400 4400 2 6400 8400 8 8900 9100 5 1000 1000 7 1100 1300 6 8700 9100 9 4100 4900 4 3100 3500 | ||
1 7900 8300 2 9400 9600 0 6200 8000 4 9200 10600 6 900 900 8 4600 5800 7 7400 9600 3 9000 10600 9 2100 2300 5 4100 4500 | ||
2 1300 1500 0 600 600 1 2200 2200 5 5300 6900 3 2500 2700 4 6400 7400 8 2000 2200 7 4600 5200 9 7200 7200 6 4600 6000 | ||
2 7500 9300 1 200 200 5 5200 5200 3 8900 10100 8 4300 5300 9 7000 7400 0 4400 5000 6 6200 6800 4 600 600 7 2300 2700 | ||
1 4300 4900 0 3400 4000 3 6000 6200 2 1300 1300 6 2800 3600 5 2000 2200 9 2900 3500 8 8300 9500 7 2900 3100 4 5400 5600 | ||
2 2900 3300 0 8600 8600 1 4600 4600 5 6500 8300 4 2900 3500 6 7900 9700 8 1700 2100 9 4500 5100 7 3500 3700 3 6900 8900 | ||
0 7200 8000 1 6300 7500 3 6700 8500 5 5100 5100 2 7700 9300 9 1000 1200 6 3600 4400 7 8300 9500 4 2600 2600 8 7300 7500 | ||
1 8100 8900 0 1300 1300 2 5800 6400 6 700 700 8 5900 6900 9 7000 8200 5 4600 4800 3 5000 5400 4 8200 9800 7 4100 4900 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# -------------------------------------------------------------------------- | ||
# Source file provided under Apache License, Version 2.0, January 2004, | ||
# http://www.apache.org/licenses/ | ||
# (c) Copyright IBM Corp. 2021, 2022 | ||
# -------------------------------------------------------------------------- | ||
|
||
""" | ||
K-means is a way of clustering points in a multi-dimensional space
where the set of points to be clustered is partitioned into k subsets.
The idea is to minimize the inter-point distances inside each cluster in
order to produce clusters which group together close points.
See https://en.wikipedia.org/wiki/K-means_clustering | ||
""" | ||
|
||
|
||
import numpy as np | ||
from docplex.cp.model import CpoModel | ||
import docplex.cp.solver.solver as solver | ||
from docplex.cp.utils import compare_natural | ||
|
||
def make_model(coords, k, trust_numerics=True):
    """
    Create a CP Optimizer model that clusters points in the K-means sense.

    coords is indexed as a 2-D array with one point per row and one
    coordinate per column; k is the number of clusters. Each point gets
    an integer variable naming its cluster, and the model minimizes the
    sum, over all points, of the squared distance between the point and
    the centre of gravity of the cluster it is assigned to.

    The objective can be expressed in two ways, chosen by trust_numerics:

    * True: for each cluster and dimension, the sum of the squared
      coordinates of the member points minus the cluster size times the
      squared centre. This is akin to computing a variance via
      E[X^2] - E[X]^2: the most efficient form, but it can be numerically
      unstable due to massive cancellation.
    * False: for each cluster and dimension, the sum of (centre - p)^2
      over the member points. More numerically stable, less efficient.

    Returns the CpoModel with its minimization objective set.
    """
    num_points, num_dims = coords.shape
    point_ids = range(num_points)

    model = CpoModel()

    # One decision variable per point: the index of the cluster it joins
    assignment = [
        model.integer_var(0, k - 1, "C_{}".format(idx)) for idx in point_ids
    ]

    # Member count of every cluster, floored at 1 so that a cluster left
    # unused never causes a division by zero in the centre formula below
    cluster_size = [
        model.max(1, model.count(assignment, cl)) for cl in range(k)
    ]

    objective = 0
    for cl in range(k):
        # Membership test of every point against cluster cl
        members = [assignment[idx] == cl for idx in point_ids]
        size = cluster_size[cl]

        for axis in range(num_dims):
            # Coordinates of all points along this axis
            column = coords[:, axis]

            # Centre of gravity of the cluster along this axis
            centre = model.scal_prod(members, column) / size

            if not trust_numerics:
                # Stable form: sum of squared deviations from the centre
                deviations = ((centre - v)**2 for v in column)
                spread = model.scal_prod(members, deviations)
            else:
                # Fast form: sum of squares minus size * centre^2
                squares = model.scal_prod(members, (v**2 for v in column))
                spread = squares - centre**2 * size

            objective += spread

    # The model's goal is the smallest total squared distance
    model.minimize(objective)
    return model
|
||
|
||
if __name__ == "__main__":
    import sys

    # Defaults, in command-line order: number of points, number of
    # dimensions, number of clusters, random seed
    settings = [500, 2, 5, 1234]

    # Every positional argument that is present overrides its default;
    # anything beyond the fourth argument is ignored
    for pos, text in enumerate(sys.argv[1:5]):
        settings[pos] = int(text)
    n, d, k, sd = settings

    print("Generating with N = {}, D = {}, K = {}".format(n, d, k))

    # Reproducible random points drawn uniformly from the unit hypercube
    np.random.seed(sd)
    coords = np.random.uniform(0, 1, size=(n, d))

    mdl = make_model(coords, k)

    # Limits shared by both solves
    limits = dict(TimeLimit=10, LogPeriod=50000)

    # First solve: restart search
    mdl.solve(SearchType="Restart", **limits)

    # Second solve: neighborhood search, attempted only when the solver
    # version compares at or above 22.1
    if compare_natural(solver.get_solver_version(), '22.1') >= 0:
        mdl.solve(SearchType="Neighborhood", **limits)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.