diff --git a/.idea/runConfigurations/TrinityMainMaven.xml b/.idea/runConfigurations/TrinityMainMaven.xml
index 76e1dcae..8919387b 100644
--- a/.idea/runConfigurations/TrinityMainMaven.xml
+++ b/.idea/runConfigurations/TrinityMainMaven.xml
@@ -12,4 +12,4 @@
-
\ No newline at end of file
+
diff --git a/build.gradle b/build.gradle
index 1d7bf680..4b92ed4b 100644
--- a/build.gradle
+++ b/build.gradle
@@ -61,6 +61,7 @@ dependencies {
}
implementation group: 'com.github.quickhull3d', name: 'quickhull3d', version: '1.0.0'
implementation group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'
+ implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
implementation group: 'org.zeromq', name: 'jeromq', version: '0.6.0'
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: '2.17.0'
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.17.0'
@@ -71,6 +72,7 @@ dependencies {
implementation group: 'org.slf4j', name: 'slf4j-api', version: '2.0.12'
testImplementation 'org.junit.jupiter:junit-jupiter:5.10.2'
+ testImplementation 'org.junit.platform:junit-platform-suite:1.10.2'
testRuntimeOnly 'org.junit.platform:junit-platform-launcher:1.10.2'
}
diff --git a/gradle.properties b/gradle.properties
index 5a0bf3d2..597cb029 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -8,5 +8,5 @@ org.gradle.java.installations.auto-detect=true
# Project Build Properties
env=dev
-javafx.version=21.0.2
+javafx.version=21.0.3
javafx.static.version=21-ea+11.2
diff --git a/nbactions-assembly.xml b/nbactions-assembly.xml
new file mode 100644
index 00000000..dd0041ae
--- /dev/null
+++ b/nbactions-assembly.xml
@@ -0,0 +1,37 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<actions>
+    <action>
+        <actionName>run</actionName>
+        <packagings>
+            <packaging>jar</packaging>
+        </packagings>
+        <goals>
+            <goal>clean</goal>
+            <goal>javafx:run</goal>
+        </goals>
+        <properties>
+            <exec.executable>java</exec.executable>
+            <exec.vmArgs>-Dprism.maxvram=2G</exec.vmArgs>
+            <exec.args>${exec.vmArgs}</exec.args>
+            <exec.mainClass>edu.jhuapl.trinity.TrinityMain</exec.mainClass>
+        </properties>
+    </action>
+    <action>
+        <actionName>debug</actionName>
+        <packagings>
+            <packaging>jar</packaging>
+        </packagings>
+        <goals>
+            <goal>clean</goal>
+            <goal>javafx:run@debug</goal>
+        </goals>
+        <properties>
+            <jpda.listen>true</jpda.listen>
+            <exec.executable>java</exec.executable>
+            <jpda.suspend>true</jpda.suspend>
+            <jpda.address>8000</jpda.address>
+            <exec.vmArgs>-Dprism.maxvram=2G</exec.vmArgs>
+            <exec.args>${exec.vmArgs} -Xdebug -Xrunjdwp:transport=dt_socket,server=n,address=8000 -classpath %classpath edu.jhuapl.trinity.App</exec.args>
+        </properties>
+    </action>
+</actions>
diff --git a/pom.xml b/pom.xml
index 874e81b4..2a6560fa 100644
--- a/pom.xml
+++ b/pom.xml
@@ -20,7 +20,7 @@
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
         <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
         <java.version>21</java.version>
-        <javafx.version>21.0.2</javafx.version>
+        <javafx.version>21.0.3</javafx.version>
         <javafx.static.version>21-ea+11.2</javafx.static.version>
         3.6.3${java.version}
@@ -40,19 +40,18 @@
         8.0.29.0.105.10.2
+        <junit.jupiter.platform.version>1.10.2</junit.jupiter.platform.version>
         0.1.30.6.021.0.71.0.03.6.1
+        <apache.commons.lang3.version>3.14.0</apache.commons.lang3.version>
         0.6.02.17.00.3.121.5.22.0.12
-        The Johns Hopkins University Applied Physics Laboratory LLC
-        LICENSE.md
-        apache_v2
         <mainClassName>edu.jhuapl.trinity.TrinityMain</mainClassName>
         Trinity
@@ -96,14 +95,20 @@
         <dependency>
             <groupId>org.junit.jupiter</groupId>
-            <artifactId>junit-jupiter-api</artifactId>
+            <artifactId>junit-jupiter</artifactId>
             <version>${junit.jupiter.version}</version>
             <scope>test</scope>
         </dependency>
         <dependency>
-            <groupId>org.junit.jupiter</groupId>
-            <artifactId>junit-jupiter-params</artifactId>
-            <version>${junit.jupiter.version}</version>
+            <groupId>org.junit.platform</groupId>
+            <artifactId>junit-platform-suite</artifactId>
+            <version>${junit.jupiter.platform.version}</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.junit.platform</groupId>
+            <artifactId>junit-platform-launcher</artifactId>
+            <version>${junit.jupiter.platform.version}</version>
             <scope>test</scope>
         </dependency>
@@ -131,6 +136,11 @@
             <artifactId>commons-math3</artifactId>
             <version>${apache.commons.math3.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-lang3</artifactId>
+            <version>${apache.commons.lang3.version}</version>
+        </dependency>
         <dependency>
             <groupId>org.zeromq</groupId>
             <artifactId>jeromq</artifactId>
@@ -266,33 +276,6 @@
                     <mainClass>${mainClassName}</mainClass>
                 </configuration>
             </plugin>
-            <plugin>
-                <groupId>org.codehaus.mojo</groupId>
-                <artifactId>license-maven-plugin</artifactId>
-                <version>${codehaus.license.plugin.version}</version>
-                <configuration>
-                    <licenseName>apache_v2</licenseName>
-                    <canUpdateCopyright>false</canUpdateCopyright>
-                    <canUpdateDescription>false</canUpdateDescription>
-                    <roots>
-                        <root>src/main/java</root>
-                        <root>src/test/java</root>
-                    </roots>
-                    <excludes>
-                        <exclude>**/*.json</exclude>
-                    </excludes>
-                </configuration>
-                <executions>
-                    <execution>
-                        <phase>process-sources</phase>
-                        <goals>
-                            <goal>update-file-header</goal>
-                        </goals>
-                    </execution>
-                </executions>
-            </plugin>
             <plugin>
                 <groupId>io.github.git-commit-id</groupId>
                 <artifactId>git-commit-id-maven-plugin</artifactId>
diff --git a/src/main/java/com/clust4j/Clust4j.java b/src/main/java/com/clust4j/Clust4j.java
new file mode 100644
index 00000000..a7ebd778
--- /dev/null
+++ b/src/main/java/com/clust4j/Clust4j.java
@@ -0,0 +1,77 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j;
+
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+
+/**
+ * The absolute super type for all clust4j objects (models and datasets)
+ * that should be able to commonly serialize their data.
+ *
+ * @author Taylor G Smith
+ */
+public abstract class Clust4j implements java.io.Serializable {
+ private static final long serialVersionUID = -4522135376738501625L;
+
+ /**
+ * Load a model from a FileInputStream
+ *
+ * @param fis the file input stream to read the serialized model from
+ * @return the deserialized Clust4j model
+ * @throws IOException
+ * @throws ClassNotFoundException
+ */
+ public static Clust4j loadObject(final FileInputStream fis) throws IOException, ClassNotFoundException {
+ ObjectInputStream in = null;
+ Clust4j bm = null;
+
+ try {
+ in = new ObjectInputStream(fis);
+ bm = (Clust4j) in.readObject();
+ } finally {
+ if (null != in)
+ in.close();
+
+ fis.close();
+ }
+
+ return bm;
+ }
+
+ /**
+ * Save a model to FileOutputStream
+ *
+ * @param fos the file output stream to write the serialized model to
+ * @throws IOException
+ */
+ public void saveObject(final FileOutputStream fos) throws IOException {
+ ObjectOutputStream out = null;
+
+ try {
+ out = new ObjectOutputStream(fos);
+ out.writeObject(this);
+ } finally {
+ if (null != out)
+ out.close();
+
+ fos.close();
+ }
+ }
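+
+    /*
+     * Illustrative round-trip sketch (a documentation aid, not part of the
+     * original clust4j API); the File argument is a hypothetical target.
+     */
+    public static Clust4j roundTrip(final Clust4j model, final java.io.File f)
+        throws IOException, ClassNotFoundException {
+        try (FileOutputStream fos = new FileOutputStream(f)) {
+            model.saveObject(fos); // serialize the model to disk
+        }
+        try (FileInputStream fis = new FileInputStream(f)) {
+            return loadObject(fis); // read it back as a Clust4j instance
+        }
+    }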
+}
diff --git a/src/main/java/com/clust4j/GlobalState.java b/src/main/java/com/clust4j/GlobalState.java
new file mode 100644
index 00000000..85e6152b
--- /dev/null
+++ b/src/main/java/com/clust4j/GlobalState.java
@@ -0,0 +1,259 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j;
+
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.Random;
+import java.util.concurrent.ForkJoinPool;
+
+/**
+ * A set of global config values used in multiple classes. Some values may
+ * be set to the user's preference, while others are final.
+ *
+ * @author Taylor G Smith
+ */
+public abstract class GlobalState {
+ /**
+ * The default random state
+ */
+ public final static Random DEFAULT_RANDOM_STATE = new Random(999);
+ public final static int MAX_ARRAY_SIZE = 25_000_000;
+
+
+ /**
+ * Holds static mathematical values
+ *
+ * @author Taylor G Smith
+ */
+ public static abstract class Mathematics {
+ /**
+ * Double.MIN_VALUE is the smallest positive double, not the most negative; this is
+ */
+ public final static double SIGNED_MIN = Double.NEGATIVE_INFINITY;
+ public final static double MAX = Double.POSITIVE_INFINITY;
+ public final static double TINY = 2.2250738585072014e-308;
+ public final static double EPS = 2.2204460492503131e-16;
+
+ /*===== Gamma function assistants =====*/
+ public final static double LOG_PI = FastMath.log(Math.PI);
+ public final static double LOG_2PI = FastMath.log(2 * Math.PI);
+ public final static double ROOT_2PI = FastMath.sqrt(2 * Math.PI);
+ /**
+ * Euler's Gamma constant
+ */
+ public final static double GAMMA = 0.577215664901532860606512090;
+ public final static double HALF_LOG2_PI = 0.91893853320467274178032973640562;
+ final static double[] GAMMA_BOUNDS = new double[]{0.001, 12.0};
+ final static double HIGH_BOUND = 171.624;
+
+ /**
+ * numerator coefficients for approximation over the interval (1,2)
+ */
+ private final static double[] p = new double[]{
+ -1.71618513886549492533811E+0,
+ 2.47656508055759199108314E+1,
+ -3.79804256470945635097577E+2,
+ 6.29331155312818442661052E+2,
+ 8.66966202790413211295064E+2,
+ -3.14512729688483675254357E+4,
+ -3.61444134186911729807069E+4,
+ 6.64561438202405440627855E+4
+ };
+
+ /**
+ * denominator coefficients for approximation over the interval (1,2)
+ */
+ private final static double[] q = new double[]{
+ -3.08402300119738975254353E+1,
+ 3.15350626979604161529144E+2,
+ -1.01515636749021914166146E+3,
+ -3.10777167157231109440444E+3,
+ 2.25381184209801510330112E+4,
+ 4.75584627752788110767815E+3,
+ -1.34659959864969306392456E+5,
+ -1.15132259675553483497211E+5
+ };
+
+ /**
+ * Abramowitz and Stegun 6.1.41
+ * Asymptotic series should be good to at least 11 or 12 figures
+ * For error analysis, see Whittaker and Watson
+ * A Course in Modern Analysis (1927), page 252
+ */
+ private final static double[] c = new double[]{
+ 1.0 / 12.0,
+ -1.0 / 360.0,
+ 1.0 / 1260.0,
+ -1.0 / 1680.0,
+ 1.0 / 1188.0,
+ -691.0 / 360360.0,
+ 1.0 / 156.0,
+ -3617.0 / 122400.0
+ };
+
+ // Any assertion failure will cause an exception to be thrown right away
+ static {
+ // These should never change
+ assert GAMMA_BOUNDS.length == 2;
+ assert p.length == 8;
+ assert p.length == q.length;
+ assert c.length == p.length;
+ }
+
+ /**
+ * Adapted from sklearn_gamma, which was in turn adapted from
+ * John D. Cook's public domain version of lgamma, from
+ * http://www.johndcook.com/stand_alone_code.html
+ *
+ * @param x
+ * @return
+ */
+ public static double gamma(double x) {
+ if (x <= 0)
+ throw new IllegalArgumentException("x must exceed 0");
+
+ // Check if in first boundary
+ int boundaryIdx = 0;
+ if (x < GAMMA_BOUNDS[boundaryIdx++])
+ return 1.0 / (x * (1.0 + GAMMA * x));
+
+ if (x < GAMMA_BOUNDS[boundaryIdx++]) {
+ double den = 1.0, num = 0.0, res, z, y = x;
+ int i, n = 0;
+ boolean lt1 = y < 1.0;
+
+ if (lt1)
+ y += 1.0;
+ else {
+ n = ((int) y) - 1;
+ y -= n;
+ }
+
+ z = y - 1;
+ for (i = 0; i < p.length; i++) {
+ num = (num + p[i]) * z;
+ den = den * z + q[i];
+ }
+
+ res = num / den + 1.0;
+
+ // Correction if arg was not initially in (1,2)
+ if (lt1)
+ res /= (y - 1.0);
+ else {
+ for (i = 0; i < n; i++, y++)
+ res *= y;
+ }
+
+ return res;
+ }
+
+ if (x > HIGH_BOUND)
+ return Double.POSITIVE_INFINITY;
+
+ return FastMath.exp(lgamma(x));
+ }
+
+ public static double lgamma(double x) {
+ if (x <= 0)
+ throw new IllegalArgumentException("x must exceed 0");
+
+ double z, sum;
+ int i;
+
+ if (x < GAMMA_BOUNDS[1])
+ return FastMath.log(FastMath.abs(gamma(x)));
+
+ z = 1.0 / (x * x);
+ sum = c[7];
+ for (i = 6; i >= 0; i--) {
+ sum *= z;
+ sum += c[i];
+ }
+
+ return (x - 0.5) * FastMath.log(x) - x + HALF_LOG2_PI + sum / x;
+ }
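+
+        /*
+         * Quick numeric sanity check (illustrative, not part of the original
+         * source): Gamma(n) = (n-1)! for positive integers, so gamma(5.0)
+         * should be ~24.0 and lgamma(5.0) should be ~log(24) ~ 3.178:
+         *
+         *   assert FastMath.abs(gamma(5.0) - 24.0) < 1e-6;
+         *   assert FastMath.abs(lgamma(5.0) - FastMath.log(24.0)) < 1e-6;
+         */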
+ }
+
+
+ /**
+ * A class to hold configurations for parallelism
+ *
+ * @author Taylor G Smith
+ */
+ public abstract static class ParallelismConf {
+ /**
+ * Matrices with number of elements exceeding this number
+ * will automatically trigger parallel events as supported
+ * in clustering methods.
+ */
+ public static final int MIN_ELEMENTS = 15000;
+
+ /**
+ * The minimum number of cores to efficiently
+ * allow parallel operations.
+ */
+ public static final int MIN_PARALLEL_CORES_RECOMMENDED = 8;
+
+ /**
+ * The minimum number of required cores to allow any
+ * parallelism at all.
+ */
+ public static final int MIN_CORES_REQUIRED = 4;
+
+ /**
+ * The number of available cores on the machine. Used for determining
+ * whether or not to use parallelism & how large parallel chunks should be.
+ */
+ public static final int NUM_CORES = Runtime.getRuntime().availableProcessors();
+
+ /**
+ * Whether to allow parallelism at all or quietly force serial jobs where necessary
+ */
+ public static boolean PARALLELISM_ALLOWED = NUM_CORES >= MIN_CORES_REQUIRED;
+
+ /**
+ * Whether parallelization is recommended for this machine.
+ * Default value is true if availableProcessors is at least 8.
+ */
+ public static final boolean PARALLELISM_RECOMMENDED = NUM_CORES >= MIN_PARALLEL_CORES_RECOMMENDED;
+
+ /**
+ * If true and the size of the vector exceeds {@value #MAX_SERIAL_VECTOR_LEN},
+ * auto schedules parallel job for applicable operations. This can slow
+ * things down on machines with a lower core count, but speed them up
+ * on machines with a higher core count. More heap space may be required.
+ * Defaults to {@link #PARALLELISM_RECOMMENDED}
+ */
+ public static boolean ALLOW_AUTO_PARALLELISM = PARALLELISM_RECOMMENDED;
+
+ /**
+ * The global ForkJoin thread pool for parallel recursive tasks.
+ */
+ final static public ForkJoinPool FJ_THREADPOOL = new ForkJoinPool();
+
+ /**
+ * The max length a vector may be before defaulting to a parallel process, if applicable
+ */
+ static public int MAX_SERIAL_VECTOR_LEN = 10_000_000;
+
+ /**
+ * The max length a parallel-processed chunk may be
+ */
+ public static int MAX_PARALLEL_CHUNK_SIZE = MAX_SERIAL_VECTOR_LEN / NUM_CORES; //2_500_000;
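+
+        /*
+         * Illustrative helper (an addition for documentation, not in the
+         * original source): one way callers might combine the flags above to
+         * decide between a serial loop and an FJ_THREADPOOL submission.
+         */
+        public static boolean shouldAutoParallelize(final int vectorLength) {
+            return PARALLELISM_ALLOWED                   // enough cores at all
+                && ALLOW_AUTO_PARALLELISM                // auto-parallelism enabled
+                && vectorLength > MAX_SERIAL_VECTOR_LEN; // beyond the serial limit
+        }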
+ }
+}
diff --git a/src/main/java/com/clust4j/NamedEntity.java b/src/main/java/com/clust4j/NamedEntity.java
new file mode 100644
index 00000000..d392016b
--- /dev/null
+++ b/src/main/java/com/clust4j/NamedEntity.java
@@ -0,0 +1,36 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j;
+
+/**
+ * Models or any {@link com.clust4j.log.Loggable}
+ * that should be able to "say their name" should
+ * implement this method.
+ *
+ *
Other considered names:
+ *
SelfProfessant
+ *
Parrot
+ *
EchoChamber
+ *
+ * :-)
+ *
+ *
+ *
+ * @author Taylor G Smith
+ */
+public interface NamedEntity {
+ public String getName();
+}
diff --git a/src/main/java/com/clust4j/algo/AbstractAutonomousClusterer.java b/src/main/java/com/clust4j/algo/AbstractAutonomousClusterer.java
new file mode 100644
index 00000000..6dc0256c
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AbstractAutonomousClusterer.java
@@ -0,0 +1,58 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.metrics.scoring.SupervisedMetric;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import static com.clust4j.metrics.scoring.UnsupervisedMetric.SILHOUETTE;
+
+public abstract class AbstractAutonomousClusterer extends AbstractClusterer implements UnsupervisedClassifier {
+ /**
+ *
+ */
+ private static final long serialVersionUID = -4704891508225126315L;
+
+ public AbstractAutonomousClusterer(RealMatrix data, BaseClustererParameters planner) {
+ super(data, planner);
+ }
+
+ /**
+ * The number of clusters this algorithm identified
+ *
+ * @return the number of clusters in the system
+ */
+ abstract public int getNumberOfIdentifiedClusters();
+
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double indexAffinityScore(int[] labels) {
+ // Propagates ModelNotFitException
+ return SupervisedMetric.INDEX_AFFINITY.evaluate(labels, getLabels());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double silhouetteScore() {
+ // Propagates ModelNotFitException
+ return SILHOUETTE.evaluate(this, getLabels());
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/AbstractCentroidClusterer.java b/src/main/java/com/clust4j/algo/AbstractCentroidClusterer.java
new file mode 100644
index 00000000..b8c2f21e
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AbstractCentroidClusterer.java
@@ -0,0 +1,494 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.NamedEntity;
+import com.clust4j.kernel.Kernel;
+import com.clust4j.log.LogTimer;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.scoring.SupervisedMetric;
+import com.clust4j.utils.MatUtils;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Random;
+
+import static com.clust4j.metrics.scoring.UnsupervisedMetric.SILHOUETTE;
+
+public abstract class AbstractCentroidClusterer extends AbstractPartitionalClusterer
+ implements CentroidLearner, Convergeable, UnsupervisedClassifier {
+
+ private static final long serialVersionUID = -424476075361612324L;
+ final public static double DEF_CONVERGENCE_TOLERANCE = 0.005; // Not same as Convergeable.DEF_TOL
+ final public static int DEF_K = BaseNeighborsModel.DEF_K;
+ final public static InitializationStrategy DEF_INIT = InitializationStrategy.AUTO;
+ final public static HashSet<Class<? extends GeometricallySeparable>> UNSUPPORTED_METRICS;
+
+ static {
+ UNSUPPORTED_METRICS = new HashSet<>();
+
+ /*
+ * Add all binary distances
+ */
+ for (Distance d : Distance.binaryDistances())
+ UNSUPPORTED_METRICS.add(d.getClass());
+
+ /*
+ * Kernels that are conditionally positive definite, or that
+ * may propagate NaNs or Infs or 100% zeros
+ */
+
+ // should be handled now by returning just one cluster...
+ //UNSUPPORTED_METRICS.add(CauchyKernel.class);
+ //UNSUPPORTED_METRICS.add(CircularKernel.class);
+ //UNSUPPORTED_METRICS.add(GeneralizedMinKernel.class);
+ //UNSUPPORTED_METRICS.add(HyperbolicTangentKernel.class);
+ //UNSUPPORTED_METRICS.add(InverseMultiquadricKernel.class);
+ //UNSUPPORTED_METRICS.add(LogKernel.class);
+ //UNSUPPORTED_METRICS.add(MinKernel.class);
+ //UNSUPPORTED_METRICS.add(MultiquadricKernel.class);
+ //UNSUPPORTED_METRICS.add(PolynomialKernel.class);
+ //UNSUPPORTED_METRICS.add(PowerKernel.class);
+ //UNSUPPORTED_METRICS.add(SplineKernel.class);
+ }
+
+
+ protected InitializationStrategy init;
+ final protected int maxIter;
+ final protected double tolerance;
+ final protected int[] init_centroid_indices;
+ final protected int m;
+
+ volatile protected boolean converged = false;
+ volatile protected double tss = 0.0;
+ volatile protected double bss = Double.NaN;
+ volatile protected double[] wss;
+
+ volatile protected int[] labels = null;
+ volatile protected int iter = 0;
+
+ /**
+ * Key is the group label, value is the corresponding centroid
+ */
+ volatile protected ArrayList<double[]> centroids = new ArrayList<>();
+
+
+ static interface Initializer {
+ int[] getInitialCentroidSeeds(AbstractCentroidClusterer model, double[][] X, int k, final Random seed);
+ }
+
+ public static enum InitializationStrategy implements java.io.Serializable, Initializer, NamedEntity {
+ AUTO {
+ @Override
+ public int[] getInitialCentroidSeeds(AbstractCentroidClusterer model, double[][] X, int k, final Random seed) {
+ if (model.dist_metric instanceof Kernel)
+ return RANDOM.getInitialCentroidSeeds(model, X, k, seed);
+ return KM_AUGMENTED.getInitialCentroidSeeds(model, X, k, seed);
+ }
+
+ @Override
+ public String getName() {
+ return "auto initialization";
+ }
+ },
+
+ /**
+ * Initialize {@link KMeans} or {@link KMedoids} with a set of randomly
+ * selected centroids to use as the initial seeds. This is the traditional
+ * initialization procedure in both KMeans and KMedoids and typically performs
+ * worse than using {@link InitializationStrategy#KM_AUGMENTED}
+ */
+ RANDOM {
+ @Override
+ public int[] getInitialCentroidSeeds(AbstractCentroidClusterer model, double[][] X, int k, final Random seed) {
+ model.init = this;
+ final int m = X.length;
+
+ // Corner case: k = m
+ if (m == k)
+ return VecUtils.arange(k);
+
+ final int[] recordIndices = VecUtils.permutation(VecUtils.arange(m), seed);
+ final int[] cent_indices = new int[k];
+ for (int i = 0; i < k; i++)
+ cent_indices[i] = recordIndices[i];
+ return cent_indices;
+ }
+
+ @Override
+ public String getName() {
+ return "random initialization";
+ }
+ },
+
+ /**
+ * Proposed in 2007 by David Arthur and Sergei Vassilvitskii, the k-means++ initialization
+ * algorithm is an approximation algorithm for the NP-hard k-means problem, a way of avoiding the
+ * sometimes poor clusterings found by the standard k-means algorithm.
+ *
+ * @see k-means++
+ * @see k-means++ paper
+ */
+ KM_AUGMENTED {
+ @Override
+ public int[] getInitialCentroidSeeds(AbstractCentroidClusterer model, double[][] X, int k, final Random seed) {
+ model.init = this;
+
+ final int m = X.length, n = X[0].length;
+ final int[] range = VecUtils.arange(k);
+ final double[][] centers = new double[k][n];
+ final int[] centerIdcs = new int[k];
+
+
+ // Corner case: k = m
+ if (m == k)
+ return range;
+
+ // First need to get row norms, which are the row sums of X * X (element-wise)
+ // True Euclidean norm would sqrt each term, but no need...
+ final double[] norms = new double[m];
+ for (int i = 0; i < m; i++)
+ for (int j = 0; j < X[i].length; j++)
+ norms[i] += X[i][j] * X[i][j];
+
+ // Arthur and Vassilvitskii reported that this helped
+ final int numTrials = FastMath.max(2 * (int) FastMath.log(k), 1);
+
+
+ // Start with a random center
+ int center_id = seed.nextInt(m);
+ centers[0] = X[center_id];
+ centerIdcs[0] = center_id;
+
+ // Initialize list of closest distances
+ double[][] closest = eucDists(new double[][]{centers[0]}, X);
+ double currentPotential = MatUtils.sum(closest);
+
+
+ // Pick the rest of the cluster starting points
+ double[] randomVals, cumSum;
+ int[] candidateIdcs;
+ double[][] candidateRows, distsToCandidates, bestDistSq;
+ int bestCandidate;
+ double bestPotential;
+
+
+ for (int i = 1; i < k; i++) { // if k == 1, will skip this
+
+ /*
+ * Generate some random vals. This is a precursor to choosing
+ * centroid candidates by sampling with probability proportional to
+ * partial distance to nearest existing centroid
+ */
+ randomVals = new double[numTrials];
+ for (int j = 0; j < randomVals.length; j++)
+ randomVals[j] = currentPotential * seed.nextDouble();
+
+
+ /* Search sorted and get new dists for candidates */
+ cumSum = MatUtils.cumSum(closest); // always will be sorted
+ candidateIdcs = searchSortedCumSum(cumSum, randomVals);
+
+ // Identify the candidates
+ candidateRows = new double[candidateIdcs.length][];
+ for (int j = 0; j < candidateRows.length; j++)
+ candidateRows[j] = X[candidateIdcs[j]];
+
+ // dists to candidates
+ distsToCandidates = eucDists(candidateRows, X);
+
+
+ // Identify best candidate...
+ bestCandidate = -1;
+ bestPotential = Double.POSITIVE_INFINITY;
+ bestDistSq = null;
+
+ for (int trial = 0; trial < numTrials; trial++) {
+ double[] trialCandidate = distsToCandidates[trial];
+ double[][] newDistSq = new double[closest.length][trialCandidate.length];
+
+ // Build min dist array
+ double newPotential = 0.0; // running sum
+ for (int j = 0; j < newDistSq.length; j++) {
+ for (int p = 0; p < trialCandidate.length; p++) {
+ newDistSq[j][p] = FastMath.min(closest[j][p], trialCandidate[p]);
+ newPotential += newDistSq[j][p];
+ }
+ }
+
+ // Store if best so far
+ if (-1 == bestCandidate || newPotential < bestPotential) {
+ bestCandidate = candidateIdcs[trial];
+ bestPotential = newPotential;
+ bestDistSq = newDistSq;
+ }
+ }
+
+
+ // Add the record...
+ centers[i] = X[bestCandidate];
+ centerIdcs[i] = bestCandidate;
+
+ // update vars outside loop
+ currentPotential = bestPotential;
+ closest = bestDistSq;
+ }
+
+
+ return centerIdcs;
+ }
+
+ @Override
+ public String getName() {
+ return "k-means++";
+ }
+ }
+ }
+
+ /**
+ * Internal method for cumsum searchsorted. Protected for testing only
+ */
+ static int[] searchSortedCumSum(double[] cumSum, double[] randomVals) {
+ final int[] populate = new int[randomVals.length];
+
+ for (int c = 0; c < populate.length; c++) {
+ populate[c] = cumSum.length - 1;
+
+ for (int cmsm = 0; cmsm < cumSum.length; cmsm++) {
+ if (randomVals[c] <= cumSum[cmsm]) {
+ populate[c] = cmsm;
+ break;
+ }
+ }
+ }
+
+ return populate;
+ }
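+
+    /*
+     * Worked example (illustrative): with cumSum = {1.0, 3.0, 6.0} and
+     * randomVals = {0.5, 3.5}, the result is {0, 2}; each random value maps to
+     * the index of the first cumulative sum that is >= the value, which is how
+     * candidates get sampled with probability proportional to distance.
+     */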
+
+ /**
+ * Internal method for computing candidate distances. Protected for testing only
+ */
+ static double[][] eucDists(double[][] centers, double[][] X) {
+ MatUtils.checkDimsForUniformity(X);
+ MatUtils.checkDimsForUniformity(centers);
+
+ final int m = X.length, n = X[0].length;
+ if (n != centers[0].length)
+ throw new DimensionMismatchException(n, centers[0].length);
+
+ int next = 0;
+ final double[][] dists = new double[centers.length][m];
+ for (double[] d : centers) {
+ for (int i = 0; i < m; i++)
+ dists[next][i] = Distance.EUCLIDEAN.getPartialDistance(d, X[i]);
+ next++;
+ }
+
+ return dists;
+ }
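+
+    /*
+     * Note: EUCLIDEAN.getPartialDistance omits the final sqrt (see the "no
+     * need" comment above), so these are squared Euclidean distances, the
+     * D(x)^2 weighting that k-means++ sampling calls for.
+     */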
+
+
+ public AbstractCentroidClusterer(RealMatrix data,
+ CentroidClustererParameters<? extends AbstractCentroidClusterer> planner) {
+ super(data, planner, planner.getK());
+
+ /*
+ * Check for prohibited dist metrics...
+ */
+ if (!isValidMetric(this.dist_metric)) {
+ warn(this.dist_metric.getName() + " is unsupported by " + getName() + "; "
+ + "falling back to default (" + defMetric().getName() + ")");
+
+ /*
+ * If this is KMedoids, we set it to Manhattan, otherwise Euclidean
+ */
+ this.setSeparabilityMetric(defMetric());
+ }
+
+ this.init = planner.getInitializationStrategy();
+ this.maxIter = planner.getMaxIter();
+ this.tolerance = planner.getConvergenceTolerance();
+ this.m = data.getRowDimension();
+
+ if (maxIter < 0) throw new IllegalArgumentException("maxIter must exceed 0");
+ if (tolerance < 0) throw new IllegalArgumentException("minChange must exceed 0");
+
+
+ // set centroids
+ final LogTimer centTimer = new LogTimer();
+ this.init_centroid_indices = init.getInitialCentroidSeeds(
+ this, this.data.getData(), k, getSeed());
+ for (int i : this.init_centroid_indices)
+ centroids.add(this.data.getRow(i));
+
+
+ info("selected centroid centers via " + init.getName() + " in " + centTimer.toString());
+ logModelSummary();
+
+ /*
+ * The TSS will always be the same -- the sum of squared distances from the mean record.
+ * We can just compute this here quick and easy.
+ */
+ final double[][] X = this.data.getDataRef();
+ final double[] mean_record = MatUtils.meanRecord(X);
+ for (int i = 0; i < m; i++) {
+ for (int j = 0; j < mean_record.length; j++) {
+ double diff = X[i][j] - mean_record[j];
+ tss += (diff * diff);
+ }
+ }
+
+ // Initialize WSS:
+ wss = VecUtils.rep(Double.NaN, k);
+ }
+
+ @Override
+ final public boolean isValidMetric(GeometricallySeparable geo) {
+ return !UNSUPPORTED_METRICS.contains(geo.getClass());
+ }
+
+ @Override
+ final protected ModelSummary modelSummary() {
+ return new ModelSummary(new Object[]{
+ "Num Rows", "Num Cols", "Metric", "K", "Allow Par.", "Max Iter", "Tolerance", "Init."
+ }, new Object[]{
+ m, data.getColumnDimension(), getSeparabilityMetric(), k,
+ parallel,
+ maxIter, tolerance, init.toString()
+ });
+ }
+
+
+ @Override
+ public boolean didConverge() {
+ synchronized (fitLock) {
+ return converged;
+ }
+ }
+
+ @Override
+ public ArrayList<double[]> getCentroids() {
+ synchronized (fitLock) {
+ final ArrayList<double[]> cent = new ArrayList<>();
+ for (double[] d : centroids)
+ cent.add(VecUtils.copy(d));
+
+ return cent;
+ }
+ }
+
+ /**
+ * Returns a copy of the classified labels
+ */
+ @Override
+ public int[] getLabels() {
+ synchronized (fitLock) {
+ return super.handleLabelCopy(labels);
+ }
+ }
+
+ @Override
+ public int getMaxIter() {
+ return maxIter;
+ }
+
+ @Override
+ public double getConvergenceTolerance() {
+ return tolerance;
+ }
+
+ /**
+ * In the corner case that k = 1, the {@link LabelEncoder}
+ * won't work, so we need to label everything as 0 and immediately return
+ */
+ protected final void labelFromSingularK(final double[][] X) {
+ labels = VecUtils.repInt(0, m);
+ wss = new double[]{tss};
+ iter++;
+ converged = true;
+ warn("k=1; converged immediately with a TSS of " + tss);
+ }
+
+ @Override
+ public int itersElapsed() {
+ synchronized (fitLock) {
+ return iter;
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double indexAffinityScore(int[] labels) {
+ // Propagates ModelNotFitException
+ return SupervisedMetric.INDEX_AFFINITY.evaluate(labels, getLabels());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int[] predict(RealMatrix newData) {
+ return CentroidUtils.predict(this, newData);
+ }
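+
+    /*
+     * Illustrative usage sketch (assumes the concrete KMeans model and its
+     * KMeansParameters planner defined elsewhere in this package; "mat" and
+     * "newData" are hypothetical RealMatrix inputs):
+     *
+     *   KMeans km = new KMeansParameters(3).fitNewModel(mat);
+     *   int[] assignments = km.predict(newData); // nearest-centroid labels
+     */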
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double silhouetteScore() {
+ // Propagates ModelNotFitException
+ return SILHOUETTE.evaluate(this, getLabels());
+ }
+
+
+ public double getTSS() {
+ // doesn't need to be synchronized, because
+ // calculated in the constructor always
+ return tss;
+ }
+
+ public double[] getWSS() {
+ synchronized (fitLock) {
+ if (null == wss) {
+ return VecUtils.rep(Double.NaN, k);
+ } else {
+ return VecUtils.copy(wss);
+ }
+ }
+ }
+
+ public double getBSS() {
+ synchronized (fitLock) {
+ return bss;
+ }
+ }
+
+ protected abstract void reorderLabelsAndCentroids();
+
+ @Override
+ protected abstract AbstractCentroidClusterer fit();
+
+ protected GeometricallySeparable defMetric() {
+ return AbstractClusterer.DEF_DIST;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/AbstractClusterer.java b/src/main/java/com/clust4j/algo/AbstractClusterer.java
new file mode 100644
index 00000000..d7a7a465
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AbstractClusterer.java
@@ -0,0 +1,487 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.GlobalState;
+import com.clust4j.NamedEntity;
+import com.clust4j.except.ModelNotFitException;
+import com.clust4j.except.NaNException;
+import com.clust4j.kernel.Kernel;
+import com.clust4j.log.Log;
+import com.clust4j.log.LogTimer;
+import com.clust4j.log.Loggable;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.DistanceMetric;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.pairwise.SimilarityMetric;
+import com.clust4j.utils.MatUtils;
+import com.clust4j.utils.TableFormatter.Table;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.UUID;
+
+/**
+ * The highest level of cluster abstraction in clust4j, AbstractClusterer
+ * provides the interface for classifier clustering (both supervised and unsupervised).
+ * It also provides all the functionality for any BaseClustererPlanner classes and logging.
+ *
+ * @author Taylor G Smith <tgsmith61591@gmail.com>
+ */
+public abstract class AbstractClusterer
+ extends BaseModel
+ implements Loggable, NamedEntity, java.io.Serializable, MetricValidator {
+
+ private static final long serialVersionUID = -3623527903903305017L;
+
+ /**
+ * Whether algorithms should by default behave in a verbose manner
+ */
+ public static boolean DEF_VERBOSE = false;
+
+ /**
+ * By default, uses the {@link GlobalState#DEFAULT_RANDOM_STATE}
+ */
+ protected final static Random DEF_SEED = GlobalState.DEFAULT_RANDOM_STATE;
+ final public static GeometricallySeparable DEF_DIST = Distance.EUCLIDEAN;
+ /**
+ * The model id
+ */
+ final private String modelKey;
+
+
+ /**
+ * Underlying data
+ */
+ final protected Array2DRowRealMatrix data;
+ /**
+ * Similarity metric
+ */
+ protected GeometricallySeparable dist_metric;
+ /**
+ * Seed for any shuffles
+ */
+ protected final Random random_state;
+ /**
+ * Whether to log verbosely
+ */
+ final private boolean verbose;
+ /**
+ * Whether to use parallelism
+ */
+ protected final boolean parallel;
+ /**
+ * Whether the entire matrix contains only one unique value
+ */
+ protected boolean singular_value;
+
+
+ /**
+ * Have any warnings occurred -- volatile because can change
+ */
+ volatile private boolean hasWarnings = false;
+ final private ArrayList<String> warnings = new ArrayList<>();
+ protected final ModelSummary fitSummary;
+
+
+ /**
+ * Build a new instance from another caller
+ *
+ * @param caller
+ */
+ protected AbstractClusterer(AbstractClusterer caller) {
+ this(caller, null);
+ }
+
+ /**
+ * Internal constructor giving precedence to the planning class if not null
+ *
+ * @param caller
+ * @param planner
+ */
+ protected AbstractClusterer(AbstractClusterer caller, BaseClustererParameters planner) {
+ this.dist_metric = null == planner ? caller.dist_metric : planner.getMetric();
+ this.verbose = null == planner ? false : planner.getVerbose(); // if another caller, default to false
+ this.modelKey = getName() + "_" + UUID.randomUUID();
+ this.random_state = null == planner ? caller.random_state : planner.getSeed();
+ this.data = caller.data; // Use the reference
+ this.parallel = caller.parallel;
+ this.fitSummary = new ModelSummary(getModelFitSummaryHeaders());
+ this.singular_value = caller.singular_value;
+ }
+
+ protected AbstractClusterer(RealMatrix data, BaseClustererParameters planner, boolean as_is) {
+
+ this.dist_metric = planner.getMetric();
+ this.verbose = planner.getVerbose();
+ this.modelKey = getName() + "_" + UUID.randomUUID();
+ this.random_state = planner.getSeed();
+
+ // Determine whether we should parallelize
+ this.parallel = planner.getParallel() && GlobalState.ParallelismConf.PARALLELISM_ALLOWED;
+
+ /*
+ * If user tried to force serial, but we just can't...
+ */
+ if (!parallel && planner.getParallel())
+ info("min num cores required for parallel: " + GlobalState.ParallelismConf.MIN_CORES_REQUIRED);
+
+ if (this.dist_metric instanceof Kernel)
+ warn("running " + getName() + " in Kernel mode can be an expensive option");
+
+ // Handle data, now...
+ this.data = as_is ?
+ (Array2DRowRealMatrix) data : // internally, always 2d...
+ initData(data);
+ if (singular_value)
+ warn("all elements in input matrix are equal (" + data.getEntry(0, 0) + ")");
+
+ this.fitSummary = new ModelSummary(getModelFitSummaryHeaders());
+ }
+
+ /**
+ * Base clusterer constructor. Sets up the distance measure,
+ * and if necessary scales data.
+ *
+ * @param data
+ * @param planner
+ */
+ protected AbstractClusterer(RealMatrix data, BaseClustererParameters planner) {
+ this(data, planner, false);
+ }
+
+
+ final private Array2DRowRealMatrix initData(final RealMatrix data) {
+ final int m = data.getRowDimension(), n = data.getColumnDimension();
+ final double[][] ref = new double[m][n];
+ final HashSet<Double> unique = new HashSet<>();
+
+ // Used to compute variance on the fly for summaries later...
+ double[] sum = new double[n];
+ double[] sumSq = new double[n];
+ double[] maxes = VecUtils.rep(Double.NEGATIVE_INFINITY, n);
+ double[] mins = VecUtils.rep(Double.POSITIVE_INFINITY, n);
+
+ // This will store summaries for each column + a header
+ ModelSummary summaries = new ModelSummary(new Object[]{
+ "Feature #", "Variance", "Std. Dev", "Mean", "Max", "Min"
+ });
+
+ /*
+ * Internally performs the copy
+ */
+ double entry;
+ for (int i = 0; i < m; i++) {
+ for (int j = 0; j < n; j++) {
+ entry = data.getEntry(i, j);
+
+ if (Double.isNaN(entry)) {
+ error(new NaNException("NaN in input data. "
+ + "Select a matrix imputation method for "
+ + "incomplete records"));
+ } else {
+ // copy the entry
+ ref[i][j] = entry;
+ unique.add(entry);
+
+ // capture stats...
+ sumSq[j] += entry * entry;
+ sum[j] += entry;
+ maxes[j] = FastMath.max(entry, maxes[j]);
+ mins[j] = FastMath.min(entry, mins[j]);
+
+ // if it's the last row, we can compute these:
+ if (i == m - 1) {
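+ // one-pass sample variance: (sumSq - sum^2/m) / (m - 1)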
+ double var = (sumSq[j] - (sum[j] * sum[j]) / (double) m) / ((double) m - 1.0);
+ if (var == 0) {
+ warn("zero variance in feature " + j);
+ }
+
+ summaries.add(new Object[]{
+ j, // feature num
+ var, // var
+ m < 2 ? Double.NaN : FastMath.sqrt(var), // std dev
+ sum[j] / (double) m, // mean
+ maxes[j], // max
+ mins[j] // min
+ });
+ }
+ }
+ }
+ }
+
+ // Log the summaries
+ summaryLogger(formatter.format(summaries));
+
+ if (unique.size() == 1)
+ this.singular_value = true;
+
+ /*
+ * Don't need to copy again, because already internally copied...
+ */
+ return new Array2DRowRealMatrix(ref, false);
+ }
+
+
+ /**
+ * A model must have the same key, data and class name
+ * in order to equal another model. It is extremely unlikely
+ * that a model will share a UUID with another. In fact, the probability
+ * of one duplicate would be about 50% if every person on
+ * Earth owned 600 million UUIDs.
+ */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (o instanceof AbstractClusterer) {
+ AbstractClusterer a = (AbstractClusterer) o;
+ if (!this.getKey().equals(a.getKey()))
+ return false;
+
+ return MatUtils.equalsExactly(this.data.getDataRef(), a.data.getDataRef())
+ && this.getClass().equals(a.getClass())
+ //&& this.hashCode() == a.hashCode()
+ ;
+ }
+
+ return false;
+ }
+
+ /**
+ * Handles all label copies and ModelNotFitExceptions.
+ * This should be called within getLabels() operations
+ *
+ * @param labels the fit labels to copy
+ * @return a copy of the labels
+ */
+ protected int[] handleLabelCopy(int[] labels) throws ModelNotFitException {
+ if (null == labels) {
+ error(new ModelNotFitException("model has not been fit yet"));
+ return null;
+ } else {
+ return VecUtils.copy(labels);
+ }
+ }
+
+ /**
+ * Copies the underlying AbstractRealMatrix data structure
+ * and returns the clone so as to prevent accidental referential
+ * alterations of the data.
+ *
+ * @return copy of data
+ */
+ public RealMatrix getData() {
+ return data.copy();
+ }
+
+
+ /**
+ * Returns the separability metric used to assess vector similarity/distance
+ *
+ * @return distance metric
+ */
+ public GeometricallySeparable getSeparabilityMetric() {
+ return dist_metric;
+ }
+
+
+ /**
+ * Get the current seed being used for random state
+ *
+ * @return the random seed
+ */
+ public Random getSeed() {
+ return random_state;
+ }
+
+ /**
+ * Whether the algorithm resulted in any warnings
+ *
+ * @return whether the clustering effort has generated any warnings
+ */
+ @Override
+ public boolean hasWarnings() {
+ return hasWarnings;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = 17;
+ return result
+ ^ (verbose ? 1 : 0)
+ ^ (getKey().hashCode())
+ ^ (dist_metric instanceof DistanceMetric ? 31 :
+ dist_metric instanceof SimilarityMetric ? 53 : 1)
+ // ^ (hasWarnings ? 1 : 0) // removed because forces state dependency
+ ^ random_state.hashCode()
+ ^ data.hashCode();
+ }
+
+
+ /**
+ * Get the model key, the model's unique UUID
+ *
+ * @return the model's unique UUID
+ */
+ public String getKey() {
+ return modelKey;
+ }
+
+
+ /**
+ * Get the state of the model's verbosity
+ *
+ * @return is the model set to verbose mode or not?
+ */
+ public boolean getVerbose() {
+ return verbose;
+ }
+
+ /**
+ * Returns a collection of warnings if there are any, otherwise null
+ *
+ * @return
+ */
+ final public Collection<String> getWarnings() {
+ return warnings.isEmpty() ? null : warnings;
+ }
+
+
+ /* -- LOGGER METHODS -- */
+ @Override
+ public void error(String msg) {
+ if (verbose) Log.err(getLoggerTag(), msg);
+ }
+
+ @Override
+ public void error(RuntimeException thrown) {
+ error(thrown.getMessage());
+ throw thrown;
+ }
+
+ @Override
+ public void warn(String msg) {
+ hasWarnings = true;
+ warnings.add(msg);
+ if (verbose) Log.warn(getLoggerTag(), msg);
+ }
+
+ @Override
+ public void info(String msg) {
+ if (verbose) Log.info(getLoggerTag(), msg);
+ }
+
+ @Override
+ public void trace(String msg) {
+ if (verbose) Log.trace(getLoggerTag(), msg);
+ }
+
+ @Override
+ public void debug(String msg) {
+ if (verbose) Log.debug(getLoggerTag(), msg);
+ }
+
+ /**
+ * Write the time the algorithm took to complete
+ *
+ * @param timer
+ */
+ @Override
+ public void sayBye(final LogTimer timer) {
+ logFitSummary();
+ info("model " + getKey() + " fit completed in " + timer.toString());
+ }
+
+ /**
+ * Used for logging the initialization summary.
+ */
+ private void logFitSummary() {
+ info("--");
+ info("Model Fit Summary:");
+ final Table tab = formatter.format(fitSummary);
+ summaryLogger(tab);
+ }
+
+ /**
+ * Used for logging the initialization summary
+ */
+ protected final void logModelSummary() {
+ info("--");
+ info("Model Init Summary:");
+ final Table tab = formatter.format(modelSummary());
+ summaryLogger(tab);
+ }
+
+ /**
+ * Handles logging of tables
+ */
+ final private void summaryLogger(Table tab) {
+ final String fmt = tab.toString();
+ final String sep = System.getProperty("line.separator");
+ final String[] summary = fmt.split(sep);
+
+ // Sometimes the fit summary can be overwhelmingly long..
+ // Only want to show top few & bottom few. (extra 1 on top for header)
+ final int top = 6, bottom = top - 1;
+ int topThresh = top, bottomThresh;
+ if (summary.length > top + bottom) {
+ // calculate the bottom thresh
+ bottomThresh = summary.length - bottom;
+ } else {
+ topThresh = summary.length;
+ bottomThresh = 0;
+ }
+
+
+ int iter = 0;
+ boolean shownBreak = false;
+ for (String line : summary) {
+ if (iter < topThresh || iter > bottomThresh)
+ info(line);
+ else if (!shownBreak) {
+ // first after top thresh
+ info(tab.getTableBreak());
+ shownBreak = true;
+ }
+
+ iter++;
+ }
+ }
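+
+    /*
+     * Worked example (illustrative): with 20 summary lines, top = 6 and
+     * bottom = 5 give bottomThresh = 15, so lines 0-5 print, the table break
+     * prints once in the middle, and lines 16-19 print.
+     */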
+
+ protected void setSeparabilityMetric(final GeometricallySeparable sep) {
+ this.dist_metric = sep;
+ }
+
+
+ /**
+ * Fits the model
+ */
+ @Override
+ abstract protected AbstractClusterer fit();
+
+ protected abstract ModelSummary modelSummary();
+
+ protected abstract Object[] getModelFitSummaryHeaders();
+}
diff --git a/src/main/java/com/clust4j/algo/AbstractDBSCAN.java b/src/main/java/com/clust4j/algo/AbstractDBSCAN.java
new file mode 100644
index 00000000..ecd338b7
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AbstractDBSCAN.java
@@ -0,0 +1,57 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import org.apache.commons.math3.linear.RealMatrix;
+
+abstract class AbstractDBSCAN extends AbstractDensityClusterer implements NoiseyClusterer {
+ private static final long serialVersionUID = 5247910788105653778L;
+
+ final public static double DEF_EPS = 0.5;
+ final public static int DEF_MIN_PTS = 5;
+
+ final protected int minPts;
+ protected double eps = DEF_EPS;
+
+ public AbstractDBSCAN(RealMatrix data, AbstractDBSCANParameters<? extends AbstractDBSCAN> planner) {
+ super(data, planner);
+
+ this.minPts = planner.getMinPts();
+
+ if (this.minPts < 1)
+ throw new IllegalArgumentException("minPts must be greater than 0");
+ }
+
+ abstract public static class AbstractDBSCANParameters<T extends AbstractDBSCAN>
+ extends BaseClustererParameters
+ implements UnsupervisedClassifierParameters<T> {
+ private static final long serialVersionUID = 765572960123009344L;
+ protected int minPts = DEF_MIN_PTS;
+
+ abstract public AbstractDBSCANParameters<T> setMinPts(final int minPts);
+
+ final public int getMinPts() {
+ return minPts;
+ }
+ }
+
+ public int getMinPts() {
+ return minPts;
+ }
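+
+    /*
+     * Illustrative usage sketch (assumes the concrete DBSCAN subclass and its
+     * DBSCANParameters planner defined elsewhere in this package; eps = 0.75
+     * is an arbitrary example value):
+     *
+     *   DBSCAN model = new DBSCANParameters(0.75).setMinPts(5).fitNewModel(mat);
+     */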
+
+ @Override
+ protected abstract AbstractDBSCAN fit();
+}
diff --git a/src/main/java/com/clust4j/algo/AbstractDensityClusterer.java b/src/main/java/com/clust4j/algo/AbstractDensityClusterer.java
new file mode 100644
index 00000000..99eabd92
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AbstractDensityClusterer.java
@@ -0,0 +1,43 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.metrics.pairwise.SimilarityMetric;
+import org.apache.commons.math3.linear.RealMatrix;
+
+public abstract class AbstractDensityClusterer extends AbstractAutonomousClusterer {
+ /**
+ *
+ */
+ private static final long serialVersionUID = 5645721633522621894L;
+
+ public AbstractDensityClusterer(RealMatrix data, BaseClustererParameters planner) {
+ super(data, planner);
+
+ checkState(this);
+ } // End constructor
+
+ protected static void checkState(AbstractClusterer ac) {
+ // Density-based clusterers should not use similarity metrics; they look
+ // for neighborhoods, which similarity metrics do not accurately represent.
+ if (ac.getSeparabilityMetric() instanceof SimilarityMetric) {
+ ac.warn("density or radius-based clustering algorithms "
+ + "should use distance metrics instead of similarity metrics. "
+ + "Falling back to default: " + DEF_DIST);
+ ac.setSeparabilityMetric(DEF_DIST);
+ }
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/AbstractPartitionalClusterer.java b/src/main/java/com/clust4j/algo/AbstractPartitionalClusterer.java
new file mode 100644
index 00000000..de3d99bf
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AbstractPartitionalClusterer.java
@@ -0,0 +1,51 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import org.apache.commons.math3.linear.RealMatrix;
+
+public abstract class AbstractPartitionalClusterer extends AbstractClusterer {
+ /**
+ *
+ */
+ private static final long serialVersionUID = 8489725366968682469L;
+ /**
+ * The number of clusters to find. This field is not final, as in
+ * some corner cases, the algorithm will modify k for convergence.
+ */
+ protected int k;
+
+ public AbstractPartitionalClusterer(
+ RealMatrix data,
+ BaseClustererParameters planner,
+ final int k) {
+ super(data, planner);
+
+ if (k < 1)
+ error(new IllegalArgumentException("k must exceed 0"));
+ if (k > data.getRowDimension())
+ error(new IllegalArgumentException("k exceeds number of records"));
+
+ this.k = this.singular_value ? 1 : k;
+ if (this.singular_value && k != 1) {
+ warn("coerced k to 1 due to equality of all elements in input matrix");
+ }
+ } // End constructor
+
+ public int getK() {
+ return k;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/AffinityPropagation.java b/src/main/java/com/clust4j/algo/AffinityPropagation.java
new file mode 100644
index 00000000..d604d983
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AffinityPropagation.java
@@ -0,0 +1,846 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.GlobalState;
+import com.clust4j.except.ModelNotFitException;
+import com.clust4j.log.Log.Tag.Algo;
+import com.clust4j.log.LogTimer;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.utils.MatUtils;
+import com.clust4j.utils.MatUtils.Axis;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Random;
+
+/**
+ * Affinity Propagation (AP)
+ * is a clustering algorithm based on the concept of "message passing" between data points.
+ * Unlike other clustering algorithms such as {@link KMeans} or {@link KMedoids},
+ * AP does not require the number of clusters to be determined or estimated before
+ * running the algorithm. Like KMedoids, AP finds "exemplars", members of the input
+ * set that are representative of clusters.
+ *
+ * @author Taylor G Smith <tgsmith61591@gmail.com>, adapted from sklearn Python implementation
+ * @see sklearn
+ */
+final public class AffinityPropagation extends AbstractAutonomousClusterer implements Convergeable, CentroidLearner {
+ private static final long serialVersionUID = 1986169131867013043L;
+
+ /**
+ * The number of stagnant iterations after which the algorithm will declare convergence
+ */
+ final public static int DEF_ITER_BREAK = 15;
+ final public static int DEF_MAX_ITER = 200;
+ final public static double DEF_DAMPING = 0.5;
+ /**
+ * By default uses minute Gaussian smoothing. It is recommended this remain
+ * true, but the {@link AffinityPropagationParameters#useGaussianSmoothing(boolean)}
+ * method can disable this option
+ */
+ final public static boolean DEF_ADD_GAUSSIAN_NOISE = true;
+ final public static HashSet<Class<? extends GeometricallySeparable>> UNSUPPORTED_METRICS;
+
+
+ /**
+ * Static initializer
+ */
+ static {
+ UNSUPPORTED_METRICS = new HashSet<>();
+
+ /*
+ * can produce negative inf, but should be OK:
+ * UNSUPPORTED_METRICS.add(CircularKernel.class);
+ * UNSUPPORTED_METRICS.add(LogKernel.class);
+ */
+
+ // Add more metrics here if necessary...
+ }
+
+ @Override
+ final public boolean isValidMetric(GeometricallySeparable geo) {
+ return !UNSUPPORTED_METRICS.contains(geo.getClass());
+ }
+
+
+ /**
+ * Damping factor
+ */
+ private final double damping;
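+
+    /*
+     * Illustrative note: each message-passing iteration blends new values into
+     * old ones, e.g. R = damping * R_old + (1 - damping) * R_new, which is why
+     * damping must lie in [0.5, 1) for stable convergence.
+     */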
+
+ /**
+ * Remove degeneracies with noise?
+ */
+ private final boolean addNoise;
+
+ /**
+ * Number of stagnant iters after which to break
+ */
+ private final int iterBreak;
+
+ /**
+ * The max iterations
+ */
+ private final int maxIter;
+
+ /**
+ * Num rows, cols
+ */
+ private final int m;
+
+ /**
+ * Min change convergence criteria
+ */
+ private final double tolerance;
+
+ /**
+ * Class labels
+ */
+ private volatile int[] labels = null;
+
+ /**
+ * Track convergence
+ */
+ private volatile boolean converged = false;
+
+ /**
+ * Number of identified clusters
+ */
+ private volatile int numClusters;
+
+ /**
+ * Count iterations
+ */
+ private volatile int iterCt = 0;
+
+ /**
+ * Sim matrix. Only use during fitting, then back to null to save space
+ */
+ private volatile double[][] sim_mat = null;
+
+ /**
+ * Holds the centroids
+ */
+ private volatile ArrayList<double[]> centroids = null;
+
+ /**
+ * Holds centroid indices
+ */
+ private volatile ArrayList<Integer> centroidIndices = null;
+
+ /**
+ * Holds the availability matrix
+ */
+ volatile private double[][] cachedA;
+
+ /**
+ * Holds the responsibility matrix
+ */
+ volatile private double[][] cachedR;
+
+
+ /**
+ * Initializes a new AffinityPropagationModel with default parameters
+ *
+ * @param data
+ */
+ protected AffinityPropagation(final RealMatrix data) {
+ this(data, new AffinityPropagationParameters());
+ }
+
+ /**
+ * Initializes a new AffinityPropagationModel with parameters
+ *
+ * @param data
+ * @param planner
+ */
+ public AffinityPropagation(final RealMatrix data, final AffinityPropagationParameters planner) {
+ super(data, planner);
+
+
+ // Check some args
+ if (planner.damping < DEF_DAMPING || planner.damping >= 1)
+ error(new IllegalArgumentException("damping "
+ + "must be between " + DEF_DAMPING + " and 1"));
+
+ this.damping = planner.damping;
+ this.iterBreak = planner.iterBreak;
+ this.m = data.getRowDimension();
+ this.tolerance = planner.minChange;
+ this.maxIter = planner.maxIter;
+ this.addNoise = planner.addNoise;
+
+ if (maxIter < 0) throw new IllegalArgumentException("maxIter must exceed 0");
+ if (tolerance < 0) throw new IllegalArgumentException("minChange must exceed 0");
+ if (iterBreak < 0) throw new IllegalArgumentException("iterBreak must exceed 0");
+
+ if (!addNoise) {
+ warn("not scaling with Gaussian noise can cause the algorithm not to converge");
+ }
+
+ /*
+ * Shouldn't be an issue with AP
+ */
+ if (!isValidMetric(this.dist_metric)) {
+ warn(this.dist_metric.getName() + " is not valid for " + getName() + ". "
+ + "Falling back to default Euclidean dist");
+ setSeparabilityMetric(DEF_DIST);
+ }
+
+ logModelSummary();
+ }
+
+ @Override
+ final protected ModelSummary modelSummary() {
+ return new ModelSummary(new Object[]{
+ "Num Rows", "Num Cols", "Metric", "Damping", "Allow Par.", "Max Iter", "Tolerance", "Add Noise"
+ }, new Object[]{
+ m, data.getColumnDimension(), getSeparabilityMetric(), damping,
+ parallel,
+ maxIter, tolerance, addNoise
+ });
+ }
+
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (o instanceof AffinityPropagation) {
+ AffinityPropagation a = (AffinityPropagation) o;
+
+ /*
+ * If cachedA is null, cachedR will be as well, so there is
+ * no need to check both. This also serves as a litmus test
+ * for whether the model has been fit yet.
+ */
+ if (null == this.cachedA ^ null == a.cachedA)
+ return false;
+
+ return super.equals(o) // check on UUID and class
+ && MatUtils.equalsExactly(this.data.getDataRef(), a.data.getDataRef())
+ && VecUtils.equalsExactly(this.labels, a.labels)
+ && this.tolerance == a.tolerance
+ && this.addNoise == a.addNoise
+ && this.maxIter == a.maxIter
+ && this.damping == a.damping;
+ }
+
+ return false;
+ }
+
+ @Override
+ public int[] getLabels() {
+ return super.handleLabelCopy(labels);
+ }
+
+ @Override
+ public boolean didConverge() {
+ return converged;
+ }
+
+ public double[][] getAvailabilityMatrix() {
+ if (null != cachedA)
+ return MatUtils.copy(cachedA);
+ throw new ModelNotFitException("model is not fit");
+ }
+
+ public double[][] getResponsibilityMatrix() {
+ if (null != cachedR)
+ return MatUtils.copy(cachedR);
+ throw new ModelNotFitException("model is not fit");
+ }
+
+ @Override
+ public int getMaxIter() {
+ return maxIter;
+ }
+
+ @Override
+ public double getConvergenceTolerance() {
+ return tolerance;
+ }
+
+ @Override
+ public int itersElapsed() {
+ return iterCt;
+ }
+
+ @Override
+ public String getName() {
+ return "AffinityPropagation";
+ }
+
+ @Override
+ public Algo getLoggerTag() {
+ return com.clust4j.log.Log.Tag.Algo.AFFINITY_PROP;
+ }
+
+ /**
+ * Remove this from scope of {@link #fit()} to avoid lots of large objects
+ * left in memory. This is more space efficient and promotes easier testing.
+ *
+ * @param X the data matrix
+ * @param metric the distance metric used to compute similarities
+ * @param seed the random source for the Gaussian noise
+ * @param addNoise whether to smooth degeneracies with Gaussian noise
+ * @return the smoothed similarity matrix
+ */
+ protected static double[][] computeSmoothedSimilarity(final double[][] X, GeometricallySeparable metric, Random seed, boolean addNoise) {
+ /*
+ * Originally, we computed similarity matrix, then refactored the diagonal vector, and
+ * then computed the following portions. We can do this all at once and save lots of passes
+ * (5?) on the order of O(M^2), condensing it all to one pass of O(M choose 2).
+ *
+ * After the sim matrix is computed, we need to do three things:
+ *
+ * 1. Create a matrix of very small values (tiny_scaled) to remove degeneracies in sim_mat
+ * 2. Multiply tiny_scaled by an extremely small value (GlobalState.Mathematics.TINY*100)
+ * 3. Create a noise matrix of random Gaussian values and add it to the similarity matrix.
+ *
+ * The methods exist to build these in three to five separate O(M^2) passes, but that's
+ * extremely expensive, so we're going to do it in one giant, convoluted loop. If you're
+ * trying to debug this, sorry...
+ *
+ * Total runtime: O(2M * M choose 2)
+ */
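+ /*
+ * Worked illustration of the smoothing step below (values made up):
+ * with EPS on the order of 1e-16 and tiny_val = TINY * 100, a
+ * similarity of -4.2 yields noise = (-4.2 * EPS + tiny_val), so the
+ * Gaussian perturbation added to each entry is many orders of
+ * magnitude smaller than the similarities themselves -- enough to
+ * break ties without altering the structure of the matrix.
+ */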
+ final int m = X.length;
+ double[][] sim_mat = new double[m][m];
+
+ int idx = 0;
+ final double tiny_val = GlobalState.Mathematics.TINY * 100;
+ final double[] vector = new double[m * m];
+ double sim, noise;
+ boolean last_iter = false;
+
+
+ // Do this a little differently... set the diagonal FIRST.
+ for (int i = 0; i < m; i++) {
+ sim = -(metric.getPartialDistance(X[i], X[i]));
+ sim_mat[i][i] = sim;
+ vector[idx++] = sim;
+ }
+
+
+ for (int i = 0; i < m - 1; i++) {
+ for (int j = i + 1; j < m; j++) { // Upper triangular
+ sim = -(metric.getPartialDistance(X[i], X[j])); // similarity
+
+ // Assign to upper and lower portion
+ sim_mat[i][j] = sim;
+ sim_mat[j][i] = sim;
+
+ // Add to the vector (twice)
+ for (int b = 0; b < 2; b++)
+ vector[idx++] = sim;
+
+ // Catch the last iteration and compute the preference (median similarity).
+ // Note: the single '=' below is an intentional assignment, not a comparison.
+ double median = 0.0;
+ if (last_iter = (i == m - 2 && j == m - 1))
+ median = VecUtils.median(vector);
+
+ if (addNoise) {
+ noise = (sim * GlobalState.Mathematics.EPS + tiny_val);
+ sim_mat[i][j] += (noise * seed.nextGaussian());
+ sim_mat[j][i] += (noise * seed.nextGaussian());
+
+ if (last_iter) { // set diag and do the noise thing.
+ noise = (median * GlobalState.Mathematics.EPS + tiny_val);
+ for (int h = 0; h < m; h++)
+ sim_mat[h][h] = median + (noise * seed.nextGaussian());
+ }
+ } else if (last_iter) {
+ // it's the last iter and no noise. Just set diag.
+ for (int h = 0; h < m; h++)
+ sim_mat[h][h] = median;
+ }
+ }
+ }
+
+ return sim_mat;
+ }
+
+
+ /**
+ * Computes the first portion of the AffinityPropagation iteration
+ * sequence in place. Separating this piece from the {@link #fit()} method
+ * itself allows for easier testing.
+ *
+ * @param A the availability matrix (read)
+ * @param S the similarity matrix (read)
+ * @param tmp staging matrix; receives A + S (written)
+ * @param I per-row argmax indices of tmp (written)
+ * @param Y per-row max values of tmp (written)
+ * @param Y2 per-row second-highest values of tmp (written)
+ */
+ protected static void affinityPiece1(double[][] A, double[][] S, double[][] tmp, int[] I, double[] Y, double[] Y2) {
+ final int m = S.length;
+
+ // Reassign tmp, create vector of arg maxes. Can
+ // assign tmp like this:
+ //
+ // tmp = MatUtils.add(A, sim_mat);
+ //
+ //
+ // But requires extra M x M pass. Also get indices of ROW max.
+ // Can do like this:
+ //
+ // I = MatUtils.argMax(tmp, Axis.ROW);
+ //
+ // But requires extra pass on order of M. Finally, capture the second
+ // highest record in each row, and store in a vector. Then row-wise
+ // scalar subtract Y from the sim_mat
+ for (int i = 0; i < m; i++) {
+
+ // Compute row maxes
+ double runningMax = Double.NEGATIVE_INFINITY;
+ double secondMax = Double.NEGATIVE_INFINITY;
+ int runningMaxIdx = 0; // Idx of max row element -- start at 0 in case the metric produces -Infs
+
+ for (int j = 0; j < m; j++) { // Create tmp as A + sim_mat
+ tmp[i][j] = A[i][j] + S[i][j];
+
+ if (tmp[i][j] > runningMax) {
+ secondMax = runningMax;
+ runningMax = tmp[i][j];
+ runningMaxIdx = j;
+ } else if (tmp[i][j] > secondMax) {
+ secondMax = tmp[i][j];
+ }
+ }
+
+ I[i] = runningMaxIdx; // Idx of max element for row
+ Y[i] = tmp[i][I[i]]; // Grab the current val
+ Y2[i] = secondMax;
+ tmp[i][I[i]] = Double.NEGATIVE_INFINITY; // Set that idx to neg inf now
+ }
+ }
+
+ /**
+ * Computes the second portion of the AffinityPropagation iteration
+ * sequence in place. Separating this piece from the {@link #fit()} method
+ * itself allows for easier testing.
+ *
+ * @param colSums accumulates the column sums of tmp (written)
+ * @param tmp staging matrix (read/written)
+ * @param I per-row argmax indices from the first piece
+ * @param S the similarity matrix (read)
+ * @param R the responsibility matrix (updated in place)
+ * @param Y per-row max values from the first piece
+ * @param Y2 per-row second-highest values from the first piece
+ * @param damping the damping factor
+ */
+ protected static void affinityPiece2(double[] colSums, double[][] tmp, int[] I,
+ double[][] S, double[][] R, double[] Y, double[] Y2, double damping) {
+
+ final int m = S.length;
+
+ // Second i thru m loop, get new max vector and then first damping.
+ // First damping ====================================
+ // This can be done like this (which is more readable):
+ //
+ // tmp = MatUtils.scalarMultiply(tmp, 1 - damping);
+ // R = MatUtils.scalarMultiply(R, damping);
+ // R = MatUtils.add(R, tmp);
+ //
+ // But it requires two extra MXM passes, which can be costly...
+ // We know R & tmp are both m X m, so we can combine the
+ // three steps all together...
+ // Finally, compute availability -- start by setting anything
+ // less than 0 to 0 in tmp. Also calc column sums in same pass...
+ int ind = 0;
+ final double omd = 1.0 - damping;
+
+ for (int i = 0; i < m; i++) {
+ // Get new max vector
+ for (int j = 0; j < m; j++)
+ tmp[i][j] = S[i][j] - Y[i];
+ tmp[ind][I[i]] = S[ind][I[i]] - Y2[ind++];
+
+ // Perform damping, then piecewise
+ // calculate column sums
+ for (int j = 0; j < m; j++) {
+ tmp[i][j] *= omd;
+ R[i][j] = (R[i][j] * damping) + tmp[i][j];
+
+ tmp[i][j] = FastMath.max(R[i][j], 0);
+ if (i != j) // Because we set diag after this outside j loop
+ colSums[j] += tmp[i][j];
+ }
+
+ tmp[i][i] = R[i][i]; // Set diagonal elements in tmp equal to those in R
+ colSums[i] += tmp[i][i];
+ }
+ }
+
+ /**
+ * Computes the third portion of the AffinityPropagation iteration
+ * sequence in place. Separating this piece from the {@link #fit()} method
+ * itself allows for easier testing.
+ *
+ * @param tmp staging matrix (read/written)
+ * @param colSums column sums from the second piece
+ * @param A the availability matrix (updated in place)
+ * @param R the responsibility matrix (read)
+ * @param mask convergence mask; 1.0 where diag(A) + diag(R) > 0 (written)
+ * @param damping the damping factor
+ */
+ protected static void affinityPiece3(double[][] tmp, double[] colSums,
+ double[][] A, double[][] R, double[] mask, double damping) {
+ final int m = A.length;
+
+ // Set any negative values to zero but keep diagonal at original
+ // Originally ran this way, but costs an extra M x M operation:
+ // tmp = MatUtils.scalarSubtract(tmp, colSums, Axis.COL);
+ // Finally, more damping...
+ // More damping ====================================
+ // This can be done like this (which is more readable):
+ //
+ // tmp = MatUtils.scalarMultiply(tmp, 1 - damping);
+ // A = MatUtils.scalarMultiply(A, damping);
+ // A = MatUtils.subtract(A, tmp);
+ //
+ // But it requires two extra MXM passes, which can be costly... O(2M^2)
+ // We know A & tmp are both m X m, so we can combine the
+ // three steps all together...
+
+ // ALSO CHECK CONVERGENCE CRITERIA
+
+ // Check convergence criteria =====================
+ // This can be done like this for readability:
+ //
+ // final double[] diagA = MatUtils.diagFromSquare(A);
+ // final double[] diagR = MatUtils.diagFromSquare(R);
+ // final double[] mask = new double[diagA.length];
+ // for(int i = 0; i < mask.length; i++)
+ // mask[i] = diagA[i] + diagR[i] > 0 ? 1d : 0d;
+ for (int i = 0; i < m; i++) {
+ for (int j = 0; j < m; j++) {
+ tmp[i][j] -= colSums[j];
+
+ if (tmp[i][j] < 0 && i != j) // Don't set diag to 0
+ tmp[i][j] = 0;
+
+ tmp[i][j] *= (1 - damping);
+ A[i][j] = (A[i][j] * damping) - tmp[i][j];
+ }
+
+ mask[i] = A[i][i] + R[i][i] > 0 ? 1.0 : 0.0;
+ }
+ }
+
+
+ @Override
+ protected AffinityPropagation fit() {
+ synchronized (fitLock) {
+ if (null != labels)
+ return this;
+
+
+ // Init labels
+ final LogTimer timer = new LogTimer();
+ labels = new int[m];
+
+ /*
+ * All elements singular
+ */
+ if (this.singular_value) {
+ warn("algorithm converged immediately due to all elements being equal in input matrix");
+ this.converged = true;
+ this.fitSummary.add(new Object[]{
+ 0, converged, timer.formatTime(), timer.formatTime(), 1, timer.wallMsg()
+ });
+
+ sayBye(timer);
+ return this;
+ }
+
+
+ sim_mat = computeSmoothedSimilarity(data.getData(), getSeparabilityMetric(), getSeed(), addNoise);
+ info("computed similarity matrix and smoothed degeneracies in " + timer.toString());
+
+
+ // Affinity propagation uses two matrices: the responsibility
+ // matrix, R, and the availability matrix, A
+ double[][] A = new double[m][m];
+ double[][] R = new double[m][m];
+ double[][] tmp = new double[m][m]; // Intermediate staging...
+
+
+ // Begin here
+ int[] I = new int[m];
+ double[][] e = new double[m][iterBreak];
+ double[] Y; // vector of arg maxes
+ double[] Y2; // vector of maxes post neg inf
+ double[] sum_e;
+
+
+ final LogTimer iterTimer = new LogTimer();
+ info("beginning affinity computations " + timer.wallMsg());
+
+
+ long iterStart = Long.MAX_VALUE;
+ for (iterCt = 0; iterCt < maxIter; iterCt++) {
+ iterStart = iterTimer.now();
+
+ /*
+ * First piece in place
+ */
+ Y = new double[m];
+ Y2 = new double[m]; // Second max for each row
+ affinityPiece1(A, sim_mat, tmp, I, Y, Y2);
+
+
+ /*
+ * Second piece in place
+ */
+ final double[] columnSums = new double[m];
+ affinityPiece2(columnSums, tmp, I, sim_mat, R, Y, Y2, damping);
+
+
+ /*
+ * Third piece in place
+ */
+ final double[] mask = new double[m];
+ affinityPiece3(tmp, columnSums, A, R, mask, damping);
+
+
+ // Set the mask in `e`
+ MatUtils.setColumnInPlace(e, iterCt % iterBreak, mask);
+ numClusters = (int) VecUtils.sum(mask);
+
+
+ if (iterCt >= iterBreak) { // Time to check convergence criteria...
+ sum_e = MatUtils.rowSums(e);
+
+ // masking
+ int maskCt = 0;
+ for (int i = 0; i < sum_e.length; i++)
+ maskCt += sum_e[i] == 0 || sum_e[i] == iterBreak ? 1 : 0;
+
+ converged = maskCt == m;
+
+ if ((converged && numClusters > 0) || iterCt == maxIter) {
+ info("converged after " + (iterCt) + " iteration" + (iterCt != 1 ? "s" : "") +
+ " in " + iterTimer.toString());
+ break;
+ } // Else did not converge...
+ } // End outer if
+
+
+ fitSummary.add(new Object[]{
+ iterCt, converged,
+ iterTimer.formatTime(iterTimer.now() - iterStart),
+ timer.formatTime(),
+ numClusters,
+ timer.wallTime()
+ });
+ } // End for
+
+
+ if (!converged) warn("algorithm did not converge");
+ else { // needs one last info
+ fitSummary.add(new Object[]{
+ iterCt, converged,
+ iterTimer.formatTime(iterTimer.now() - iterStart),
+ timer.formatTime(),
+ numClusters,
+ timer.wallTime()
+ });
+ }
+
+
+ info("labeling clusters from availability and responsibility matrices");
+
+
+ // sklearn line: I = np.where(np.diag(A + R) > 0)[0]
+ final ArrayList<Integer> arWhereOver0 = new ArrayList<>();
+
+ // Get diagonal of A + R and add to arWhereOver0 if > 0
+ // Could do this: MatUtils.diagFromSquare(MatUtils.add(A, R));
+ // But takes 3M time... this takes M
+ for (int i = 0; i < m; i++)
+ if (A[i][i] + R[i][i] > 0)
+ arWhereOver0.add(i);
+
+ // Reassign to array, so whole thing takes 1M + K rather than 3M + K
+ I = new int[arWhereOver0.size()];
+ for (int j = 0; j < I.length; j++) I[j] = arWhereOver0.get(j);
+
+
+ // Assign final K -- sklearn line: K = I.size # Identify exemplars
+ numClusters = I.length;
+ info(numClusters + " cluster" + (numClusters != 1 ? "s" : "") + " identified");
+
+
+ // Assign the labels
+ if (numClusters > 0) {
+
+ /*
+ * I holds the columns we want out of sim_mat,
+ * retrieve these cols, do a row-wise argmax to get 'c'
+ * sklearn line: c = np.argmax(S[:, I], axis=1)
+ */
+ double[][] over0cols = new double[m][numClusters];
+ int over_idx = 0;
+ for (int i : I)
+ MatUtils.setColumnInPlace(over0cols, over_idx++, MatUtils.getColumn(sim_mat, i));
+
+
+
+ /*
+ * Identify clusters
+ * sklearn line: c[I] = np.arange(K) # Identify clusters
+ */
+ int[] c = MatUtils.argMax(over0cols, Axis.ROW);
+ int k = 0;
+ for (int i : I)
+ c[i] = k++;
+
+
+ /* Refine the final set of exemplars and clusters and return results
+ * sklearn:
+ *
+ * for k in range(K):
+ * ii = np.where(c == k)[0]
+ * j = np.argmax(np.sum(S[ii[:, np.newaxis], ii], axis=0))
+ * I[k] = ii[j]
+ */
+ ArrayList<Integer> ii = null;
+ int[] iii = null;
+ for (k = 0; k < numClusters; k++) {
+ // indices where c == k; sklearn line:
+ // ii = np.where(c == k)[0]
+ ii = new ArrayList<>();
+ for (int u = 0; u < c.length; u++)
+ if (c[u] == k)
+ ii.add(u);
+
+ // Big block to break down sklearn process
+ // overall sklearn line: j = np.argmax(np.sum(S[ii[:, np.newaxis], ii], axis=0))
+ iii = new int[ii.size()]; // convert to int array for MatUtils
+ for (int j = 0; j < iii.length; j++) iii[j] = ii.get(j);
+
+
+ // sklearn line: S[ii[:, np.newaxis], ii]
+ double[][] cube = MatUtils.getRows(MatUtils.getColumns(sim_mat, iii), iii);
+ double[] colSums = MatUtils.colSums(cube);
+ final int argMax = VecUtils.argMax(colSums);
+
+
+ // sklearn: I[k] = ii[j]
+ I[k] = iii[argMax];
+ }
+
+
+ // sklearn line: c = np.argmax(S[:, I], axis=1)
+ double[][] colCube = MatUtils.getColumns(sim_mat, I);
+ c = MatUtils.argMax(colCube, Axis.ROW);
+
+
+ // sklearn line: c[I] = np.arange(K)
+ for (int j = 0; j < I.length; j++) // I.length == K, == numClusters
+ c[I[j]] = j;
+
+
+ // sklearn line: labels = I[c]
+ for (int j = 0; j < m; j++)
+ labels[j] = I[c[j]];
+
+
+ /*
+ * Reduce labels to a sorted, gapless, list
+ * sklearn line: cluster_centers_indices = np.unique(labels)
+ */
+ centroidIndices = new ArrayList<>(numClusters);
+ for (Integer i : labels) // force autobox
+ if (!centroidIndices.contains(i)) // Not a race condition; we're inside the synchronized fit block
+ centroidIndices.add(i);
+
+ /*
+ * final label assignment...
+ * sklearn line: labels = np.searchsorted(cluster_centers_indices, labels)
+ */
+ for (int i = 0; i < labels.length; i++)
+ labels[i] = centroidIndices.indexOf(labels[i]);
+
+ /*
+ * Don't forget to assign the centroids!
+ */
+ this.centroids = new ArrayList<>();
+ for (Integer idx : centroidIndices) {
+ this.centroids.add(this.data.getRow(idx));
+ }
+ } else {
+ centroids = new ArrayList<>(); // Empty
+ centroidIndices = new ArrayList<>(); // Empty
+ for (int i = 0; i < m; i++)
+ labels[i] = -1; // Missing
+ }
+
+
+ // Clean up
+ sim_mat = null;
+
+ // Since cachedA/R are volatile, it's more expensive to make potentially hundreds(+)
+ // of writes to a volatile class member. To save this time, reassign A/R only once.
+ cachedA = A;
+ cachedR = R;
+
+ sayBye(timer);
+
+ return this;
+ }
+
+ } // End fit
+
+ @Override
+ public int getNumberOfIdentifiedClusters() {
+ return numClusters;
+ }
+
+ @Override
+ final protected Object[] getModelFitSummaryHeaders() {
+ return new Object[]{
+ "Iter. #", "Converged", "Iter. Time", "Tot. Time", "Num Clusters", "Wall"
+ };
+ }
+
+ @Override
+ public ArrayList<double[]> getCentroids() {
+ if (null == centroids)
+ error(new ModelNotFitException("model has not yet been fit"));
+
+ final ArrayList<double[]> cent = new ArrayList<>();
+ for (double[] d : centroids)
+ cent.add(VecUtils.copy(d));
+
+ return cent;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int[] predict(RealMatrix newData) {
+ return CentroidUtils.predict(this, newData);
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/AffinityPropagationParameters.java b/src/main/java/com/clust4j/algo/AffinityPropagationParameters.java
new file mode 100644
index 00000000..9373a358
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/AffinityPropagationParameters.java
@@ -0,0 +1,111 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.Random;
+
+/**
+ * A model setup class for {@link AffinityPropagation}. This class houses all
+ * of the hyper-parameter settings to build an {@link AffinityPropagation} instance
+ * using the {@link #fitNewModel(RealMatrix)} method.
+ *
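+ * <p>
+ * A minimal usage sketch (the {@code data} variable stands in for
+ * whatever {@link RealMatrix} the caller supplies):
+ * <pre>{@code
+ * AffinityPropagation model = new AffinityPropagationParameters()
+ *     .setDampingFactor(0.75)
+ *     .setMaxIter(300)
+ *     .setVerbose(true)
+ *     .fitNewModel(data);
+ * int[] labels = model.getLabels();
+ * }</pre>
+ *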
+ * @author Taylor G Smith
+ */
+public class AffinityPropagationParameters
+ extends BaseClustererParameters
+ implements UnsupervisedClassifierParameters<AffinityPropagation> {
+
+ private static final long serialVersionUID = -6096855634412545959L;
+ protected int maxIter = AffinityPropagation.DEF_MAX_ITER;
+ protected double minChange = AffinityPropagation.DEF_TOL;
+ protected int iterBreak = AffinityPropagation.DEF_ITER_BREAK;
+ protected double damping = AffinityPropagation.DEF_DAMPING;
+ protected boolean addNoise = AffinityPropagation.DEF_ADD_GAUSSIAN_NOISE;
+
+ public AffinityPropagationParameters() { /* Default constructor */ }
+
+ public AffinityPropagationParameters useGaussianSmoothing(boolean b) {
+ this.addNoise = b;
+ return this;
+ }
+
+ @Override
+ public AffinityPropagation fitNewModel(RealMatrix data) {
+ return new AffinityPropagation(data, this.copy()).fit();
+ }
+
+ @Override
+ public AffinityPropagationParameters copy() {
+ return new AffinityPropagationParameters()
+ .setDampingFactor(damping)
+ .setIterBreak(iterBreak)
+ .setMaxIter(maxIter)
+ .setMinChange(minChange)
+ .setSeed(seed)
+ .setMetric(metric)
+ .setVerbose(verbose)
+ .useGaussianSmoothing(addNoise)
+ .setForceParallel(parallel);
+ }
+
+ public AffinityPropagationParameters setDampingFactor(final double damp) {
+ this.damping = damp;
+ return this;
+ }
+
+ public AffinityPropagationParameters setIterBreak(final int iters) {
+ this.iterBreak = iters;
+ return this;
+ }
+
+ public AffinityPropagationParameters setMaxIter(final int max) {
+ this.maxIter = max;
+ return this;
+ }
+
+ public AffinityPropagationParameters setMinChange(final double min) {
+ this.minChange = min;
+ return this;
+ }
+
+ @Override
+ public AffinityPropagationParameters setSeed(Random rand) {
+ seed = rand;
+ return this;
+ }
+
+ @Override
+ public AffinityPropagationParameters setForceParallel(boolean b) {
+ this.parallel = b;
+ return this;
+ }
+
+ @Override
+ public AffinityPropagationParameters setVerbose(boolean b) {
+ verbose = b;
+ return this;
+ }
+
+ @Override
+ public AffinityPropagationParameters setMetric(GeometricallySeparable dist) {
+ this.metric = dist;
+ return this;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/BallTree.java b/src/main/java/com/clust4j/algo/BallTree.java
new file mode 100644
index 00000000..27a06bf4
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/BallTree.java
@@ -0,0 +1,212 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.log.Loggable;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.DistanceMetric;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.pairwise.MinkowskiDistance;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.HashSet;
+
+/**
+ * In computer science, a ball tree, balltree or metric tree,
+ * is a space partitioning data structure for organizing points
+ * in a multi-dimensional space. The ball tree gets its name from the
+ * fact that it partitions data points into a nested set of hyperspheres
+ * known as "balls". The resulting data structure has characteristics
+ * that make it useful for a number of applications, most notably
+ * nearest neighbor search.
+ *
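+ * <p>
+ * A minimal construction sketch (the {@code data} matrix is assumed
+ * to be supplied by the caller):
+ * <pre>{@code
+ * BallTree tree = new BallTree(data, 30, Distance.EUCLIDEAN);
+ * Neighborhood hood = tree.query(data.getData(), 5, true, true);
+ * }</pre>
+ *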
+ * @author Taylor G Smith
+ * @see NearestNeighborHeapSearch
+ * @see <a href="https://en.wikipedia.org/wiki/Ball_tree">Ball tree</a>
+ */
+public class BallTree extends NearestNeighborHeapSearch {
+ private static final long serialVersionUID = -6424085914337479234L;
+ public final static HashSet<Class<? extends GeometricallySeparable>> VALID_METRICS;
+
+ static {
+ VALID_METRICS = new HashSet<>();
+
+ /*
+ * Want all distance metrics EXCEPT binary dist metrics
+ * and Canberra -- it tends to behave oddly on non-normalized data
+ */
+ for (Distance dm : Distance.values()) {
+ if (!dm.isBinaryDistance() && !dm.equals(Distance.CANBERRA)) {
+ VALID_METRICS.add(dm.getClass());
+ }
+ }
+
+ VALID_METRICS.add(MinkowskiDistance.class);
+ VALID_METRICS.add(Distance.HAVERSINE.MI.getClass());
+ VALID_METRICS.add(Distance.HAVERSINE.KM.getClass());
+ }
+
+
+ @Override
+ protected boolean checkValidDistMet(GeometricallySeparable dist) {
+ return VALID_METRICS.contains(dist.getClass());
+ }
+
+
+ public BallTree(final RealMatrix X) {
+ super(X);
+ }
+
+ public BallTree(final RealMatrix X, int leaf_size) {
+ super(X, leaf_size);
+ }
+
+ public BallTree(final RealMatrix X, DistanceMetric dist) {
+ super(X, dist);
+ }
+
+ public BallTree(final RealMatrix X, Loggable logger) {
+ super(X, logger);
+ }
+
+ public BallTree(final RealMatrix X, int leaf_size, DistanceMetric dist) {
+ super(X, leaf_size, dist);
+ }
+
+ public BallTree(final RealMatrix X, int leaf_size, DistanceMetric dist, Loggable logger) {
+ super(X, leaf_size, dist, logger);
+ }
+
+ /**
+ * Constructor with logger and distance metric
+ *
+ * @param X the data matrix
+ * @param dist the distance metric
+ * @param logger the logger to report progress to
+ */
+ public BallTree(final RealMatrix X, DistanceMetric dist, Loggable logger) {
+ super(X, dist, logger);
+ }
+
+ protected BallTree(final double[][] X, int leaf_size, DistanceMetric dist, Loggable logger) {
+ super(X, leaf_size, dist, logger);
+ }
+
+
+ @Override
+ void allocateData(NearestNeighborHeapSearch tree, int n_nodes, int n_features) {
+ tree.node_bounds = new double[1][n_nodes][n_features];
+ }
+
+ @Override
+ void initNode(NearestNeighborHeapSearch tree, int i_node, int idx_start, int idx_end) {
+ int n_points = idx_end - idx_start, i, j, n_features = tree.N_FEATURES;
+ double radius = 0;
+ int[] idx_array = tree.idx_array;
+ double[][] data = tree.data_arr;
+ double[] centroid = tree.node_bounds[0][i_node], this_pt;
+
+ // Determine centroid
+ for (j = 0; j < n_features; j++)
+ centroid[j] = 0;
+
+ for (i = idx_start; i < idx_end; i++) {
+ this_pt = data[idx_array[i]];
+
+ for (j = 0; j < n_features; j++)
+ centroid[j] += this_pt[j];
+ }
+
+ // Update centroids
+ for (j = 0; j < n_features; j++)
+ centroid[j] /= n_points;
+
+
+ // determine node radius
+ for (i = idx_start; i < idx_end; i++)
+ radius = FastMath.max(radius,
+ tree.rDist(centroid, data[idx_array[i]]));
+
+ tree.node_data[i_node].radius = tree.dist_metric.partialDistanceToDistance(radius);
+ tree.node_data[i_node].idx_start = idx_start;
+ tree.node_data[i_node].idx_end = idx_end;
+ }
+
+ @Override
+ final BallTree newInstance(double[][] arr, int leaf, DistanceMetric dist, Loggable logger) {
+ return new BallTree(new Array2DRowRealMatrix(arr, false), leaf, dist, logger);
+ }
+
+ @Override
+ double minDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ double dist_pt = tree.dist(pt, tree.node_bounds[0][i_node]);
+ return FastMath.max(0, dist_pt - tree.node_data[i_node].radius);
+ }
+
+ @Override
+ double minRDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2, int iNode2) {
+ return tree1.dist_metric.distanceToPartialDistance(minDistDual(tree1, iNode1, tree2, iNode2));
+ }
+
+ @Override
+ double minRDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ return tree.dist_metric.distanceToPartialDistance(minDist(tree, i_node, pt));
+ }
+
+ /*
+ @Override
+ double maxDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ double dist_pt = tree.dist(pt, tree.node_bounds[0][i_node]);
+ return dist_pt + tree.node_data[i_node].radius;
+ }
+
+ @Override
+ double maxRDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ return tree.dist_metric.distanceToPartialDistance(maxDist(tree, i_node, pt));
+ }
+ */
+
+ @Override
+ double maxRDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2, int iNode2) {
+ return tree1.dist_metric.distanceToPartialDistance(maxDistDual(tree1, iNode1, tree2, iNode2));
+ }
+
+ @Override
+ double maxDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2, int iNode2) {
+ double dist_pt = tree1.dist(tree2.node_bounds[0][iNode2], tree1.node_bounds[0][iNode1]);
+ return dist_pt + tree1.node_data[iNode1].radius + tree2.node_data[iNode2].radius;
+ }
+
+ @Override
+ double minDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2, int iNode2) {
+ double dist_pt = tree1.dist(tree2.node_bounds[0][iNode2],
+ tree1.node_bounds[0][iNode1]);
+ return FastMath.max(0,
+ (dist_pt
+ - tree1.node_data[iNode1].radius
+ - tree2.node_data[iNode2].radius));
+ }
+
+ @Override
+ void minMaxDist(NearestNeighborHeapSearch tree, int i_node, double[] pt, MutableDouble minDist, MutableDouble maxDist) {
+ double dist_pt = tree.dist(pt, tree.node_bounds[0][i_node]);
+ double rad = tree.node_data[i_node].radius;
+ minDist.value = FastMath.max(0, dist_pt - rad);
+ maxDist.value = dist_pt + rad;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/BaseClassifier.java b/src/main/java/com/clust4j/algo/BaseClassifier.java
new file mode 100644
index 00000000..88d88407
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/BaseClassifier.java
@@ -0,0 +1,49 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.metrics.scoring.SupervisedMetric;
+import com.clust4j.metrics.scoring.UnsupervisedMetric;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import static com.clust4j.metrics.scoring.UnsupervisedMetric.SILHOUETTE;
+
+/**
+ * An interface for classifiers, both supervised and unsupervised.
+ *
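+ * <p>
+ * Typical use once a model has been fit (sketch):
+ * <pre>{@code
+ * int[] labels = model.getLabels();      // training labels, in record order
+ * int[] preds  = model.predict(newData); // labels for unseen rows
+ * }</pre>
+ *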
+ * @author Taylor G Smith
+ */
+public interface BaseClassifier extends java.io.Serializable {
+ public final static SupervisedMetric DEF_SUPERVISED_METRIC = SupervisedMetric.BINOMIAL_ACCURACY;
+ public final static UnsupervisedMetric DEF_UNSUPERVISED_METRIC = SILHOUETTE;
+
+ /**
+ * Returns a copy of the assigned class labels in
+ * record order
+ *
+ * @return a copy of the class labels, in record order
+ */
+ public int[] getLabels();
+
+ /**
+ * Predict on new data
+ *
+ * @param newData the new data to predict labels for
+ * @return the predicted labels for each row in newData
+ * @throws ModelNotFitException if the model hasn't yet been fit
+ */
+ public int[] predict(RealMatrix newData);
+}
diff --git a/src/main/java/com/clust4j/algo/BaseClassifierParameters.java b/src/main/java/com/clust4j/algo/BaseClassifierParameters.java
new file mode 100644
index 00000000..2c4560ab
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/BaseClassifierParameters.java
@@ -0,0 +1,23 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.utils.DeepCloneable;
+
+public interface BaseClassifierParameters extends DeepCloneable {
+ @Override
+ public BaseClassifierParameters copy();
+}
diff --git a/src/main/java/com/clust4j/algo/BaseClustererParameters.java b/src/main/java/com/clust4j/algo/BaseClustererParameters.java
new file mode 100644
index 00000000..88baa098
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/BaseClustererParameters.java
@@ -0,0 +1,69 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.Clust4j;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.utils.DeepCloneable;
+
+import java.util.Random;
+
+/**
+ * Base planner class that many clustering algorithms
+ * will extend with static inner classes. Some clustering
+ * algorithms will require more parameters and must provide
+ * the interface for the getting/setting of such parameters.
+ *
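+ * <p>
+ * For example, with {@link AffinityPropagationParameters} (sketch):
+ * <pre>{@code
+ * BaseClustererParameters params = new AffinityPropagationParameters()
+ *     .setVerbose(true)
+ *     .setSeed(new Random(42));
+ * }</pre>
+ *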
+ * @author Taylor G Smith
+ */
+abstract public class BaseClustererParameters
+ extends Clust4j // So all are serializable
+ implements DeepCloneable, BaseClassifierParameters {
+ private static final long serialVersionUID = -5830795881133834268L;
+
+ protected boolean parallel,
+ verbose = AbstractClusterer.DEF_VERBOSE;
+ protected Random seed = AbstractClusterer.DEF_SEED;
+ protected GeometricallySeparable metric = AbstractClusterer.DEF_DIST;
+
+ @Override
+ abstract public BaseClustererParameters copy();
+
+ abstract public BaseClustererParameters setSeed(final Random rand);
+
+ abstract public BaseClustererParameters setVerbose(final boolean b);
+
+ abstract public BaseClustererParameters setMetric(final GeometricallySeparable dist);
+
+ abstract public BaseClustererParameters setForceParallel(final boolean b);
+
+ final public GeometricallySeparable getMetric() {
+ return metric;
+ }
+
+ final public boolean getParallel() {
+ return parallel;
+ }
+
+ final public Random getSeed() {
+ return seed;
+ }
+
+ final public boolean getVerbose() {
+ return verbose;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/BaseModel.java b/src/main/java/com/clust4j/algo/BaseModel.java
new file mode 100644
index 00000000..48d950ed
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/BaseModel.java
@@ -0,0 +1,46 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.Clust4j;
+import com.clust4j.utils.SynchronicityLock;
+import com.clust4j.utils.TableFormatter;
+
+import java.text.NumberFormat;
+
+abstract public class BaseModel extends Clust4j implements java.io.Serializable {
+ private static final long serialVersionUID = 4707757741169405063L;
+ public final static TableFormatter formatter;
+
+ // Initializers
+ static {
+ NumberFormat nf = NumberFormat.getInstance(TableFormatter.DEFAULT_LOCALE);
+ nf.setMaximumFractionDigits(5);
+ formatter = new TableFormatter(nf);
+ formatter.leadWithEmpty = false;
+ formatter.setWhiteSpace(1);
+ }
+
+ /**
+ * The lock to synchronize on for fits
+ */
+ protected final Object fitLock = new SynchronicityLock();
+
+ /**
+ * This should be synchronized and thread-safe
+ */
+ protected abstract BaseModel fit();
+}
diff --git a/src/main/java/com/clust4j/algo/BaseNeighborsModel.java b/src/main/java/com/clust4j/algo/BaseNeighborsModel.java
new file mode 100644
index 00000000..ee395e09
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/BaseNeighborsModel.java
@@ -0,0 +1,283 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.GlobalState;
+import com.clust4j.except.ModelNotFitException;
+import com.clust4j.metrics.pairwise.DistanceMetric;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import org.apache.commons.math3.linear.RealMatrix;
+
+abstract public class BaseNeighborsModel extends AbstractClusterer {
+ private static final long serialVersionUID = 1054047329248586585L;
+
+ public static final NeighborsAlgorithm DEF_ALGO = NeighborsAlgorithm.AUTO;
+ public static final int DEF_LEAF_SIZE = 30;
+ public static final int DEF_K = 5;
+ public static final double DEF_RADIUS = 5.0;
+ public final static boolean DUAL_TREE_SEARCH = false;
+ public final static boolean SORT = true;
+
+ protected Integer kNeighbors = null;
+ protected Double radius = null;
+ protected boolean radiusMode;
+ protected int leafSize, m;
+ protected double[][] fit_X;
+ protected NearestNeighborHeapSearch tree;
+ protected NeighborsAlgorithm alg;
+
+ /**
+ * Resultant neighborhood from fit method
+ */
+ protected volatile Neighborhood res;
+
+ interface TreeBuilder extends MetricValidator {
+ public NearestNeighborHeapSearch buildTree(RealMatrix data,
+ int leafSize, BaseNeighborsModel logger);
+ }
+
+ public static enum NeighborsAlgorithm implements TreeBuilder {
+ AUTO {
+ @Override
+ public NearestNeighborHeapSearch buildTree(RealMatrix data,
+ int leafSize, BaseNeighborsModel logger) {
+
+ NeighborsAlgorithm alg = delegateAlgorithm(data);
+ return alg.buildTree(data, leafSize, logger);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable geo) {
+ throw new UnsupportedOperationException("auto has no metric validity criteria");
+ }
+
+ },
+
+ KD_TREE {
+ @Override
+ public NearestNeighborHeapSearch buildTree(RealMatrix data,
+ int leafSize, BaseNeighborsModel logger) {
+ logger.alg = this;
+ return new KDTree(data, leafSize, handleMetric(this, logger), logger);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ return KDTree.VALID_METRICS.contains(g.getClass());
+ }
+ },
+
+ BALL_TREE {
+ @Override
+ public NearestNeighborHeapSearch buildTree(RealMatrix data,
+ int leafSize, BaseNeighborsModel logger) {
+ logger.alg = this;
+ return new BallTree(data, leafSize, handleMetric(this, logger), logger);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ return BallTree.VALID_METRICS.contains(g.getClass());
+ }
+ };
+
+ private static NeighborsAlgorithm delegateAlgorithm(RealMatrix arm) {
+ int mn = arm.getColumnDimension() * arm.getRowDimension();
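+ // Heuristic: larger inputs get the ball tree, which tends to scale
+ // better; smaller inputs default to the kd-tree.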
+ return mn > GlobalState.ParallelismConf.MIN_ELEMENTS ?
+ BALL_TREE : KD_TREE;
+ }
+
+ private static DistanceMetric handleMetric(NeighborsAlgorithm na, BaseNeighborsModel logger) {
+ GeometricallySeparable g = logger.dist_metric;
+ if (!na.isValidMetric(g)) {
+ logger.warn(g.getName() + " is not a valid metric for " + na + ". "
+ + "Falling back to default Euclidean");
+ logger.setSeparabilityMetric(DEF_DIST);
+ }
+
+ return (DistanceMetric) logger.dist_metric;
+ }
+ }
+
+ @Override
+ final public boolean isValidMetric(GeometricallySeparable g) {
+ return this.alg.isValidMetric(g);
+ }
+
+
+ protected BaseNeighborsModel(AbstractClusterer caller, BaseNeighborsPlanner<? extends BaseNeighborsModel> planner) {
+ super(caller, planner);
+ init(planner);
+ }
+
+ protected BaseNeighborsModel(RealMatrix data, BaseNeighborsPlanner<? extends BaseNeighborsModel> planner, boolean as_is) {
+ super(data, planner, as_is);
+ init(planner);
+ }
+
+ protected BaseNeighborsModel(RealMatrix data, BaseNeighborsPlanner<? extends BaseNeighborsModel> planner) {
+ super(data, planner);
+ init(planner);
+ }
+
+ final private void init(BaseNeighborsPlanner<? extends BaseNeighborsModel> planner) {
+ this.kNeighbors = planner.getK();
+ this.radius = planner.getRadius();
+ this.leafSize = planner.getLeafSize();
+
+ radiusMode = null != radius;
+
+ /*
+ if(!(planner.getSep() instanceof DistanceMetric)) {
+ warn(planner.getSep() + " not a valid metric for neighbors models. "
+ + "Falling back to default: " + DEF_DIST);
+ super.setSeparabilityMetric(DEF_DIST);
+ }
+ */
+
+ if (leafSize < 1)
+ throw new IllegalArgumentException("leafsize must be positive");
+
+ /*
+ * Internally handles metric validation...
+ */
+ this.tree = planner.getAlgorithm().buildTree(this.data, this.leafSize, this);
+
+ // Get the data ref from the tree
+ fit_X = tree.getData();
+ this.m = fit_X.length;
+ }
+
+ abstract public static class BaseNeighborsPlanner<T extends BaseNeighborsModel>
+ extends BaseClustererParameters
+ implements NeighborsClassifierParameters<T> {
+ private static final long serialVersionUID = 8356804193088162871L;
+
+ protected int leafSize = DEF_LEAF_SIZE;
+ protected NeighborsAlgorithm algo = DEF_ALGO;
+
+ @Override
+ abstract public T fitNewModel(RealMatrix d);
+
+ abstract public BaseNeighborsPlanner<T> setAlgorithm(NeighborsAlgorithm algo);
+
+ abstract public Integer getK();
+
+ abstract public Double getRadius();
+
+ final public int getLeafSize() {
+ return leafSize;
+ }
+
+ final public NeighborsAlgorithm getAlgorithm() {
+ return algo;
+ }
+ }
+
+ public Neighborhood getNeighbors() {
+ if (null == res)
+ throw new ModelNotFitException("model not yet fit");
+ return res.copy();
+ }
+
+ /**
+ * A class to query the tree for neighborhoods in parallel
+ *
+ * @author Taylor G Smith
+ */
+ abstract static class ParallelNeighborhoodSearch extends ParallelChunkingTask<Neighborhood> {
+ private static final long serialVersionUID = -1600812794470325448L;
+
+ final BaseNeighborsModel model;
+ final double[][] distances;
+ final int[][] indices;
+ final int lo;
+ final int hi;
+
+ public ParallelNeighborhoodSearch(double[][] X, BaseNeighborsModel model) {
+ super(X); // this auto-chunks the data
+
+ this.model = model;
+ this.lo = 0;
+ this.hi = strategy.getNumChunks(X);
+
+ /*
+ * First get the length...
+ */
+ int length = 0;
+ for (Chunk c : this.chunks)
+ length += c.size();
+
+ this.distances = new double[length][];
+ this.indices = new int[length][];
+ }
+
+ public ParallelNeighborhoodSearch(ParallelNeighborhoodSearch task, int lo, int hi) {
+ super(task);
+
+ this.model = task.model;
+ this.lo = lo;
+ this.hi = hi;
+ this.distances = task.distances;
+ this.indices = task.indices;
+ }
+
+ @Override
+ public Neighborhood reduce(Chunk chunk) {
+ Neighborhood n = query(model.tree, chunk.get());
+
+ // assign to low index, since that's how we retrieved the chunk...
+ final int start = chunk.start, end = start + chunk.size();
+ double[][] d = n.getDistances();
+ int[][] i = n.getIndices();
+
+ // Set the distances and indices in place...
+ for (int j = start, idx = 0; j < end; j++, idx++) {
+ this.distances[j] = d[idx];
+ this.indices[j] = i[idx];
+ }
+
+ return n;
+ }
+
+ @Override
+ protected Neighborhood compute() {
+ if (hi - lo <= 1) { // generally should equal one...
+ return reduce(chunks.get(lo));
+ } else {
+ int mid = this.lo + (this.hi - this.lo) / 2;
+ ParallelNeighborhoodSearch left = newInstance(this, this.lo, mid);
+ ParallelNeighborhoodSearch right = newInstance(this, mid, this.hi);
+
+ left.fork();
+ right.compute();
+ left.join();
+
+ return new Neighborhood(distances, indices);
+ }
+ }
+
+ abstract ParallelNeighborhoodSearch newInstance(ParallelNeighborhoodSearch p, int lo, int hi);
+
+ abstract Neighborhood query(NearestNeighborHeapSearch tree, double[][] X);
+ }
+
+
+ abstract Neighborhood getNeighbors(RealMatrix matrix);
+
+ @Override
+ abstract protected BaseNeighborsModel fit();
+}
diff --git a/src/main/java/com/clust4j/algo/BoruvkaAlgorithm.java b/src/main/java/com/clust4j/algo/BoruvkaAlgorithm.java
new file mode 100644
index 00000000..fce60903
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/BoruvkaAlgorithm.java
@@ -0,0 +1,765 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.algo.NearestNeighborHeapSearch.NodeData;
+import com.clust4j.log.LogTimer;
+import com.clust4j.log.Loggable;
+import com.clust4j.metrics.pairwise.DistanceMetric;
+import com.clust4j.metrics.pairwise.Pairwise;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * A graph traversal algorithm used in identifying the minimum spanning tree
+ * in a graph for which all edge weights are distinct. Used in conjunction with
+ * {@link HDBSCAN}, and adapted from the HDBSCAN python package.
+ *
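+ * <p>
+ * Internal usage sketch (all argument values here are illustrative):
+ * <pre>{@code
+ * BoruvkaAlgorithm boruvka = new BoruvkaAlgorithm(tree, 5, metric, 30, true, 1.0, logger);
+ * double[][] mst = boruvka.alg.spanningTree(); // rows of [source, sink, weight]
+ * }</pre>
+ *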
+ * @author Taylor G Smith
+ * @see <a href="https://en.wikipedia.org/wiki/Bor%C5%AFvka%27s_algorithm">Boruvka's algorithm</a>
+ */
+class BoruvkaAlgorithm implements java.io.Serializable {
+ private static final long serialVersionUID = 3935595821188876442L;
+
+ // the initialization reorganizes the trees
+ final protected Boruvka alg;
+
+ private final NearestNeighborHeapSearch outer_tree;
+ private final int minSamples;
+ private final DistanceMetric metric;
+ private final boolean approxMinSpanTree;
+ private final int leafSize;
+ private final Loggable logger;
+ private final double alpha;
+
+ protected BoruvkaAlgorithm(NearestNeighborHeapSearch tree, int min_samples,
+ DistanceMetric metric, int leafSize, boolean approx_min_span_tree,
+ double alpha, Loggable logger) {
+
+ this.outer_tree = tree;
+ this.minSamples = min_samples;
+ this.metric = metric;
+ this.leafSize = leafSize;
+ this.approxMinSpanTree = approx_min_span_tree;
+ this.alpha = alpha;
+ this.logger = logger;
+
+
+ // Create the actual solver -- if using logger,
+ // updates with info in the actual algorithm
+ alg = (tree instanceof KDTree) ?
+ new KDTreeBoruvAlg() :
+ new BallTreeBoruvAlg();
+ }
+
+
+ protected static class BoruvkaUnionFind extends HDBSCAN.TreeUnionFind {
+ BoruvkaUnionFind(int N) {
+ super(N);
+ }
+ }
+
+
+ protected static double ballTreeMinDistDual(double rad1, double rad2, int node1, int node2, double[][] centroidDist) {
+ double distPt = centroidDist[node1][node2];
+ return FastMath.max(0, (distPt - rad1 - rad2));
+ }
+
+ /*
+ * Similar to {@link KDTree}.minRDistDual(...) but
+ * uses one node bounds array instead of two instances of
+ * {@link NearestNeighborHeapSearch}
+ * @param metric
+ * @param node1
+ * @param node2
+ * @param nodeBounds
+ * @param n
+ * @return
+ *
+ static double kdTreeMinDistDual(DistanceMetric metric, int node1, int node2, double[][][] nodeBounds, int n) {
+ return metric.partialDistanceToDistance(kdTreeMinRDistDual(metric, node1, node2, nodeBounds, n));
+ }
+ */
+
+ protected static double kdTreeMinRDistDual(DistanceMetric metric, int node1, int node2, double[][][] nodeBounds, int n) {
+ double d, d1, d2, rdist = 0.0;
+ boolean inf = metric.getP() == Double.POSITIVE_INFINITY;
+ int j;
+
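+ // Per dimension, (d1 + |d1|) = 2 * max(d1, 0), so 0.5 * d is the gap
+ // between the two nodes' bounding intervals (0 if they overlap).
+ // Accumulating pow(0.5 * d, p) -- or taking the max for p = infinity --
+ // yields the reduced distance between the closest possible pair of points.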
+ for (j = 0; j < n; j++) {
+ d1 = nodeBounds[0][node1][j] - nodeBounds[1][node2][j];
+ d2 = nodeBounds[0][node2][j] - nodeBounds[1][node1][j];
+ d = (d1 + FastMath.abs(d1)) + (d2 + FastMath.abs(d2));
+
+ rdist =
+ inf ? FastMath.max(rdist, 0.5 * d) :
+ rdist + FastMath.pow(0.5 * d, metric.getP());
+ }
+
+ return rdist;
+ }
+
+
+ /**
+ * The {@link NearestNeighborHeapSearch}
+ * tree traversal algorithm
+ *
+ * @author Taylor G Smith
+ */
+ protected abstract class Boruvka {
+ final static int INIT_VAL = -1;
+
+ final NearestNeighborHeapSearch coreDistTree = outer_tree;
+ final NearestNeighborHeapSearch TREE;
+ final BoruvkaUnionFind componentUnionFind;
+
+ final double[][] tree_data_ref;
+ final double[][][] node_bounds;
+ final int[] idx_array;
+ final NodeData[] node_data_ref;
+ final boolean partialDistTransform;
+
+ int numPoints, numFeatures,
+ numNodes, numEdges;
+ double[] bounds;
+ int[] components,
+ componentOfPoint,
+ componentOfNode,
+ candidateNeighbors,
+ candidatePoint;
+ double[] candidateDistance;
+ double[][] edges;
+ double[] coreDistance;
+
+ Boruvka(boolean partialTrans, NearestNeighborHeapSearch TREE) {
+ this.TREE = TREE;
+ this.tree_data_ref = TREE.getDataRef();
+ this.node_bounds = TREE.getNodeBoundsRef();
+ this.idx_array = TREE.getIndexArrayRef();
+ this.node_data_ref = TREE.getNodeDataRef();
+
+ this.numPoints = this.tree_data_ref.length;
+ this.numFeatures = this.tree_data_ref[0].length;
+ this.numNodes = this.node_data_ref.length;
+
+ this.components = VecUtils.arange(numPoints);
+ this.bounds = new double[numNodes];
+ this.componentOfPoint = new int[numPoints];
+ this.componentOfNode = new int[numNodes];
+ this.candidateNeighbors = new int[numPoints];
+ this.candidatePoint = new int[numPoints];
+ this.candidateDistance = new double[numPoints];
+ this.edges = new double[numPoints - 1][3];
+ this.componentUnionFind = new BoruvkaUnionFind(numPoints);
+
+ LogTimer s = new LogTimer();
+ this.partialDistTransform = partialTrans;
+
+ initComponents();
+ computeBounds();
+
+ if (null != logger)
+ logger.info("completed Boruvka nearest neighbor search in " + s.toString());
+ }
+
+ final void initComponents() {
+ int n;
+
+ for (n = 0; n < this.numPoints; n++) {
+ this.componentOfPoint[n] = n;
+ this.candidateNeighbors[n] = INIT_VAL;
+ this.candidatePoint[n] = INIT_VAL;
+ this.candidateDistance[n] = Double.MAX_VALUE;
+ }
+
+ for (n = 0; n < numNodes; n++)
+ this.componentOfNode[n] = -(n + 1);
+ }
+
+ final double[][] spanningTree() {
+ int numComponents = this.tree_data_ref.length;
+
+ while (numComponents > 1) {
+ this.dualTreeTraversal(0, 0);
+ numComponents = this.updateComponents();
+ }
+
+ return this.edges;
+ }
+
+ final int updateComponents() {
+ int source, sink, c, component, n, i, p, currentComponent,
+ currentSrcComponent, currentSinkComponent, child1, child2,
+ lastNumComponents;
+ NodeData nodeInfo;
+
+ // For each component there should be a:
+ // - candidate point (a point in the component)
+ // - candidate neighbor (the point to join with)
+ // - candidate_distance (the distance from point to neighbor)
+ //
+ // We will go through and add an edge to the edge list
+ // for each of these, and then union the two points
+ // together in the union-find structure
+ for (c = 0; c < this.components.length; c++ /* <- tee-hee */) {
+ component = this.components[c];
+ source = this.candidatePoint[component];
+ sink = this.candidateNeighbors[component];
+
+ //Src or sink is undefined...
+ if (source == INIT_VAL || sink == INIT_VAL)
+ continue;
+
+ currentSrcComponent = this.componentUnionFind.find(source);
+ currentSinkComponent = this.componentUnionFind.find(sink);
+
+
+ // Already joined these so ignore this edge
+ if (currentSrcComponent == currentSinkComponent) {
+ this.candidatePoint[component] = INIT_VAL;
+ this.candidateNeighbors[component] = INIT_VAL;
+ this.candidateDistance[component] = Double.MAX_VALUE;
+ continue;
+ }
+
+ // Set edge
+ this.edges[numEdges][0] = source;
+ this.edges[numEdges][1] = sink;
+ this.edges[numEdges][2] = this.partialDistTransform ?
+ metric.partialDistanceToDistance(
+ this.candidateDistance[component]) :
+ this.candidateDistance[component];
+ this.numEdges++;
+
+ // Join
+ this.componentUnionFind.union(source, sink);
+
+ // Reset everything and check for termination condition
+ this.candidateDistance[component] = Double.MAX_VALUE;
+ if (this.numEdges == this.numPoints - 1) {
+ this.components = this.componentUnionFind.components();
+ return components.length;
+ }
+ }
+
+
+ // After joining everything, we go through to determine
+ // the components of each point for an easier lookup. Makes
+ // for faster pruning later...
+ for (n = 0; n < this.tree_data_ref.length; n++)
+ this.componentOfPoint[n] = this.componentUnionFind.find(n);
+
+
+ for (n = this.node_data_ref.length - 1; n >= 0; n--) {
+ nodeInfo = this.node_data_ref[n];
+
+ // If node is leaf, check that every point in node is same component
+ if (nodeInfo.isLeaf()) {
+ currentComponent = this.componentOfPoint[idx_array[nodeInfo.start()]];
+
+ boolean found = false;
+ for (i = nodeInfo.start() + 1; i < nodeInfo.end(); i++) {
+ p = idx_array[i];
+ if (componentOfPoint[p] != currentComponent) {
+ found = true;
+ break;
+ }
+ }
+
+ // Alternative to the python for... else construct.
+ if (!found)
+ this.componentOfNode[n] = currentComponent;
+ }
+
+ // If not leaf, check both child nodes are same component
+ else {
+ child1 = 2 * n + 1;
+ child2 = 2 * n + 2;
+
+ if (this.componentOfNode[child1] == this.componentOfNode[child2])
+ this.componentOfNode[n] = this.componentOfNode[child1];
+ }
+ }
+
+
+ // This is a tie breaking method
+ if (approxMinSpanTree) {
+ lastNumComponents = this.components.length;
+ components = this.componentUnionFind.components();
+
+ if (components.length == lastNumComponents) // i.e., the component count did not change
+ for (n = 0; n < numNodes; n++) // Reset
+ bounds[n] = Double.MAX_VALUE;
+
+ } else {
+ this.components = this.componentUnionFind.components();
+ for (n = 0; n < numNodes; n++)
+ this.bounds[n] = Double.MAX_VALUE;
+ }
+
+ return components.length;
+ }
+
+ abstract void computeBounds();
+
+ abstract int dualTreeTraversal(int node1, int node2);
+ }
+
+ protected class KDTreeBoruvAlg extends Boruvka {
+ KDTreeBoruvAlg() {
+ super(true, new KDTree(
+ new Array2DRowRealMatrix(outer_tree.getDataRef(), false),
+ leafSize, metric, logger));
+ }
+
+ @Override
+ void computeBounds() {
+ int n, i, m;
+
+ // The python code uses a breadth-first search, but
+ // we eliminated the breadth-first option in favor of depth-first
+ // for all cases for the time being.
+ Neighborhood queryResult =
+ TREE.query(tree_data_ref, minSamples + 1, true, true);
+
+ double[][] knnDist = queryResult.getDistances();
+ int[][] knnIndices = queryResult.getIndices();
+
+ // Assign the core distance array and change to rdist...
+ this.coreDistance = new double[knnDist.length];
+ for (i = 0; i < coreDistance.length; i++)
+ coreDistance[i] = metric
+ .distanceToPartialDistance(
+ knnDist[i][minSamples]);
+
+ for (n = 0; n < numPoints; n++) {
+ for (i = 1; i < minSamples + 1; i++) {
+ m = knnIndices[n][i];
+
+ if (this.coreDistance[m] <= this.coreDistance[n]) {
+ this.candidatePoint[n] = n;
+ this.candidateNeighbors[n] = m;
+ this.candidateDistance[n] = this.coreDistance[n];
+ break;
+ }
+ }
+ }
+
+ this.updateComponents();
+ for (n = 0; n < numNodes; n++)
+ this.bounds[n] = Double.MAX_VALUE;
+ }
+
+ @Override
+ int dualTreeTraversal(int node1, int node2) {
+ int[] pointIndices1, pointIndices2;
+ int i, j, p, q, parent;
+
+ double nodeDist, d, mrDist, newBound,
+ newUpperBound, newLowerBound,
+ leftDist, rightDist;
+
+ NodeData node1Info = node_data_ref[node1],
+ node2Info = node_data_ref[node2];
+
+ int component1, component2, left, right;
+
+ // Distance btwn query and ref nodes
+ nodeDist = kdTreeMinRDistDual(metric, node1, node2,
+ this.node_bounds, this.numFeatures);
+
+ // If dist < current bound and nodes are not in the
+ // same component, we continue
+ if (nodeDist < this.bounds[node1]) {
+ if (this.componentOfNode[node1] == this.componentOfNode[node2]
+ && this.componentOfNode[node1] >= 0)
+ return 0;
+ else {
+ /*
+ * Pass. This is the only condition in which
+ * the method will continue without exiting early
+ */
+ }
+ } else
+ return 0;
+
+
+ // If both nodes are leaves
+ if (node1Info.isLeaf() && node2Info.isLeaf()) {
+ newUpperBound = 0.0;
+ newLowerBound = Double.MAX_VALUE;
+
+ // Build the indices
+ pointIndices1 = new int[node1Info.end() - node1Info.start()];
+ pointIndices2 = new int[node2Info.end() - node2Info.start()];
+
+ // Populate the indices
+ for (i = node1Info.start(), j = 0; i < node1Info.end(); i++, j++)
+ pointIndices1[j] = this.idx_array[i];
+ for (i = node2Info.start(), j = 0; i < node2Info.end(); i++, j++)
+ pointIndices2[j] = this.idx_array[i];
+
+
+ for (i = 0; i < pointIndices1.length; i++) {
+ p = pointIndices1[i];
+ component1 = this.componentOfPoint[p];
+
+ if (this.coreDistance[p] > this.candidateDistance[component1])
+ continue;
+
+ for (j = 0; j < pointIndices2.length; j++) {
+ q = pointIndices2[j];
+ component2 = this.componentOfPoint[q];
+
+ if (this.coreDistance[q] > this.candidateDistance[component1])
+ continue;
+
+
+ // They belong to different components
+ if (component1 != component2) {
+
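+ // Mutual reachability (in reduced/partial distance space):
+ // mrDist = max(d / alpha, coreDistance[p], coreDistance[q])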
+ d = metric.getPartialDistance(this.tree_data_ref[p], this.tree_data_ref[q]);
+
+ mrDist = FastMath.max(
+ // Avoid repeated division overhead
+ (alpha == 1.0 ? d : d / alpha),
+
+ // Nested max
+ FastMath.max(this.coreDistance[p],
+ this.coreDistance[q]));
+
+ if (mrDist < this.candidateDistance[component1]) {
+ this.candidateDistance[component1] = mrDist;
+ this.candidateNeighbors[component1] = q;
+ this.candidatePoint[component1] = p;
+ }
+ }
+ } // end for j
+
+ newUpperBound = FastMath.max(newUpperBound, this.candidateDistance[component1]);
+ newLowerBound = FastMath.min(newLowerBound, this.candidateDistance[component1]);
+ } // end for i
+
+ // Calc new bound
+ newBound = FastMath.min(newUpperBound, newLowerBound + 2 * node1Info.radius());
+
+ // Reassign new bound to min bounds[node1]
+ if (newBound < this.bounds[node1]) {
+ this.bounds[node1] = newBound;
+
+ // Propagate the tightened bound up the tree: a parent's bound is the max of its children's bounds
+ while (node1 > 0) {
+ parent = (node1 - 1) / 2;
+ left = 2 * parent + 1;
+ right = 2 * parent + 2;
+
+ newBound = FastMath.max(this.bounds[left], this.bounds[right]);
+ if (newBound < this.bounds[parent]) {
+ this.bounds[parent] = newBound;
+ node1 = parent;
+ } else break;
+ } // end while
+ } // end if inner
+ } // end case 1 if
+
+
+ // If node1 is a leaf, or node2 is the larger internal node, descend into node2's children
+ else if (node1Info.isLeaf()
+ || (!node2Info.isLeaf()
+ && node2Info.radius() > node1Info.radius())) {
+
+ left = 2 * node2 + 1;
+ right = 2 * node2 + 2;
+
+ node2Info = this.node_data_ref[left];
+ leftDist = kdTreeMinRDistDual(metric,
+ node1, left, node_bounds, this.numFeatures);
+
+ node2Info = this.node_data_ref[right];
+ rightDist = kdTreeMinRDistDual(metric,
+ node1, right, node_bounds, this.numFeatures);
+
+ if (leftDist < rightDist) {
+ this.dualTreeTraversal(node1, left);
+ this.dualTreeTraversal(node1, right);
+
+ } else { // Navigate in opposite order
+ this.dualTreeTraversal(node1, right);
+ this.dualTreeTraversal(node1, left);
+ }
+ } // end case 2 if
+
+
+ // Otherwise node2 is a leaf or the smaller node, so descend into node1's children
+ else {
+ left = 2 * node1 + 1;
+ right = 2 * node1 + 2;
+
+ node1Info = this.node_data_ref[left];
+ leftDist = kdTreeMinRDistDual(metric,
+ left, node2, node_bounds, this.numFeatures);
+
+ node1Info = this.node_data_ref[right];
+ rightDist = kdTreeMinRDistDual(metric,
+ right, node2, node_bounds, this.numFeatures);
+
+ if (leftDist < rightDist) {
+ this.dualTreeTraversal(left, node2);
+ this.dualTreeTraversal(right, node2);
+
+ } else {
+ this.dualTreeTraversal(right, node2);
+ this.dualTreeTraversal(left, node2);
+ }
+ }
+
+
+ return 0;
+ }
+ }
+
+ protected class BallTreeBoruvAlg extends Boruvka {
+ final double[][] centroidDistances;
+
+ BallTreeBoruvAlg() {
+ super(false, new BallTree(
+ new Array2DRowRealMatrix(outer_tree.getDataRef(), false),
+ leafSize, metric, logger));
+
+ // Compute pairwise dist matrix for node_bounds
+ centroidDistances = Pairwise.getDistance(node_bounds[0], metric, false, false);
+ }
+
+ @Override
+ void computeBounds() {
+ int n, i, m;
+
+ // No longer doing breadth-first searches
+ Neighborhood queryResult =
+ TREE.query(tree_data_ref, minSamples, true, true);
+
+ double[][] knnDist = queryResult.getDistances();
+ int[][] knnIndices = queryResult.getIndices();
+
+ // Assign the core distance array...
+ this.coreDistance = new double[knnDist.length];
+ for (i = 0; i < coreDistance.length; i++)
+ coreDistance[i] = knnDist[i][minSamples - 1];
+
+ for (n = 0; n < numPoints; n++) {
+ for (i = minSamples - 1; i > 0; i--) {
+ m = knnIndices[n][i];
+
+ if (this.coreDistance[m] <= this.coreDistance[n]) {
+ this.candidatePoint[n] = n;
+ this.candidateNeighbors[n] = m;
+ this.candidateDistance[n] = this.coreDistance[n];
+ }
+ }
+ }
+
+ updateComponents();
+
+ for (n = 0; n < numNodes; n++)
+ this.bounds[n] = Double.MAX_VALUE;
+ }
+
+ @Override
+ int dualTreeTraversal(int node1, int node2) {
+ int[] pointIndices1, pointIndices2;
+ int i, j, p, q, parent; //, child1, child2
+
+ double nodeDist, d, mrDist, newBound,
+ newUpperBound, newLowerBound,
+ boundMax, boundMin,
+ leftDist, rightDist;
+
+ NodeData node1Info = node_data_ref[node1],
+ node2Info = node_data_ref[node2], parentInfo, leftInfo, rightInfo;
+
+ int component1, component2, left, right;
+
+ // Distance btwn query and ref nodes
+ nodeDist = ballTreeMinDistDual(node1Info.radius(),
+ node2Info.radius(), node1, node2,
+ this.centroidDistances);
+
+ // If dist < current bound and nodes are not in the
+ // same component, we continue
+ if (nodeDist < this.bounds[node1]) {
+ if (this.componentOfNode[node1] == this.componentOfNode[node2]
+ && this.componentOfNode[node1] >= 0)
+ return 0;
+ else {
+ /*
+ * Pass. This is the only condition in which
+ * the method will continue without exiting early
+ */
+ }
+ } else
+ return 0;
+
+
+ // If both nodes are leaves
+ if (node1Info.isLeaf() && node2Info.isLeaf()) {
+ newUpperBound = Double.NEGATIVE_INFINITY;
+ newLowerBound = Double.MAX_VALUE;
+ newBound = 0.0;
+
+ // Build the indices
+ pointIndices1 = new int[node1Info.end() - node1Info.start()];
+ pointIndices2 = new int[node2Info.end() - node2Info.start()];
+
+ // Populate the indices
+ for (i = node1Info.start(), j = 0; i < node1Info.end(); i++, j++)
+ pointIndices1[j] = this.idx_array[i];
+ for (i = node2Info.start(), j = 0; i < node2Info.end(); i++, j++)
+ pointIndices2[j] = this.idx_array[i];
+
+
+ for (i = 0; i < pointIndices1.length; i++) {
+ p = pointIndices1[i];
+ component1 = this.componentOfPoint[p];
+
+ if (this.coreDistance[p] > this.candidateDistance[component1])
+ continue;
+
+ for (j = 0; j < pointIndices2.length; j++) {
+ q = pointIndices2[j];
+ component2 = this.componentOfPoint[q];
+
+ if (this.coreDistance[q] > this.candidateDistance[component1])
+ continue;
+
+ // They belong to different components
+ if (component1 != component2) {
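+ // Same mutual reachability computation as the KD-tree case,
+ // but in true distance space (no rdist conversion here).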
+ d = metric.getDistance(this.tree_data_ref[p], this.tree_data_ref[q]);
+
+ mrDist = FastMath.max(
+ // Avoid repeated division overhead
+ (alpha == 1.0 ? d : d / alpha),
+
+ // Nested max
+ FastMath.max(this.coreDistance[p],
+ this.coreDistance[q]));
+
+ if (mrDist < this.candidateDistance[component1]) {
+ this.candidateDistance[component1] = mrDist;
+ this.candidateNeighbors[component1] = q;
+ this.candidatePoint[component1] = p;
+ }
+ }
+ } // end for j
+
+ newUpperBound = FastMath.max(newUpperBound, this.candidateDistance[component1]);
+ newLowerBound = FastMath.min(newLowerBound, this.candidateDistance[component1]);
+ } // end for i
+
+ // Calc new bound
+ newBound = FastMath.min(newUpperBound, newLowerBound + 2 * node1Info.radius());
+
+ // Reassign new bound to min bounds[node1]
+ if (newBound < this.bounds[node1]) {
+ this.bounds[node1] = newBound;
+
+ // propagate bounds up...
+ while (node1 > 0) {
+ parent = (node1 - 1) / 2;
+ left = 2 * parent + 1;
+ right = 2 * parent + 2;
+
+ parentInfo = this.node_data_ref[parent];
+ leftInfo = this.node_data_ref[left];
+ rightInfo = this.node_data_ref[right];
+
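+ // boundMax mirrors the KD-tree propagation; boundMin uses the
+ // parent/child radius gap (triangle inequality) to tighten further.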
+ boundMax = FastMath.max(this.bounds[left], this.bounds[right]);
+ boundMin = FastMath.min(this.bounds[left] + 2 * (parentInfo.radius() - leftInfo.radius()),
+ this.bounds[right] + 2 * (parentInfo.radius() - rightInfo.radius()));
+
+ if (boundMin > 0)
+ newBound = FastMath.min(boundMax, boundMin);
+ else
+ newBound = boundMax;
+
+ if (newBound < this.bounds[parent]) {
+ this.bounds[parent] = newBound;
+ node1 = parent;
+ } else break;
+ } // end while
+ } // end if inner
+ } // end case 1 if
+
+
+ // If node1 is a leaf, or node2 is the larger internal node, descend into node2's children
+ else if (node1Info.isLeaf()
+ || (!node2Info.isLeaf()
+ && node2Info.radius() > node1Info.radius())) {
+ left = 2 * node2 + 1;
+ right = 2 * node2 + 2;
+
+ node2Info = this.node_data_ref[left];
+ leftDist = ballTreeMinDistDual(node1Info.radius(),
+ node2Info.radius(), node1, left, this.centroidDistances);
+
+ node2Info = this.node_data_ref[right];
+ rightDist = ballTreeMinDistDual(node1Info.radius(),
+ node2Info.radius(), node1, right, this.centroidDistances);
+
+ if (leftDist < rightDist) {
+ this.dualTreeTraversal(node1, left);
+ this.dualTreeTraversal(node1, right);
+
+ } else { // Navigate in opposite order
+ this.dualTreeTraversal(node1, right);
+ this.dualTreeTraversal(node1, left);
+ }
+ } // end case 2 if
+
+
+ // Otherwise node2 is a leaf or the smaller node, so descend into node1's children
+ else {
+ left = 2 * node1 + 1;
+ right = 2 * node1 + 2;
+
+ node1Info = this.node_data_ref[left];
+ leftDist = ballTreeMinDistDual(node1Info.radius(),
+ node2Info.radius(), left, node2, this.centroidDistances);
+
+ node1Info = this.node_data_ref[right];
+ rightDist = ballTreeMinDistDual(node1Info.radius(),
+ node2Info.radius(), right, node2, this.centroidDistances);
+
+ if (leftDist < rightDist) {
+ this.dualTreeTraversal(left, node2);
+ this.dualTreeTraversal(right, node2);
+
+ } else {
+ this.dualTreeTraversal(right, node2);
+ this.dualTreeTraversal(left, node2);
+ }
+ }
+
+
+ return 0;
+ }
+ }
+
+ protected final double[][] spanningTree() {
+ return alg.spanningTree();
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/CentroidClustererParameters.java b/src/main/java/com/clust4j/algo/CentroidClustererParameters.java
new file mode 100644
index 00000000..1d6bbbc0
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/CentroidClustererParameters.java
@@ -0,0 +1,49 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.algo.AbstractCentroidClusterer.InitializationStrategy;
+import org.apache.commons.math3.linear.RealMatrix;
+
+public abstract class CentroidClustererParameters<T extends AbstractCentroidClusterer> extends BaseClustererParameters
+ implements UnsupervisedClassifierParameters<T>, ConvergeablePlanner {
+
+ private static final long serialVersionUID = -1984508955251863189L;
+ protected int k = AbstractCentroidClusterer.DEF_K;
+ protected double minChange = AbstractCentroidClusterer.DEF_CONVERGENCE_TOLERANCE;
+
+ @Override
+ abstract public T fitNewModel(RealMatrix mat);
+
+ @Override
+ abstract public int getMaxIter();
+
+ abstract public InitializationStrategy getInitializationStrategy();
+
+ abstract public CentroidClustererParameters<T> setConvergenceCriteria(final double min);
+
+ abstract public CentroidClustererParameters<T> setInitializationStrategy(final InitializationStrategy strat);
+
+ final public int getK() {
+ return k;
+ }
+
+ @Override
+ final public double getConvergenceTolerance() {
+ return minChange;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/CentroidLearner.java b/src/main/java/com/clust4j/algo/CentroidLearner.java
new file mode 100644
index 00000000..86788b86
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/CentroidLearner.java
@@ -0,0 +1,83 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.except.ModelNotFitException;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+public interface CentroidLearner extends java.io.Serializable {
+ /**
+ * A standalone mixin class to handle predictions from {@link CentroidLearner}
+ * classes that are also a {@link BaseClassifier} and a subclass of {@link AbstractClusterer}.
+ *
+ * @author Taylor G Smith
+ */
+ static abstract class CentroidUtils {
+
+ /**
+ * Returns a matrix with the centroids.
+ *
+ * @param centroids the centroid records to stack into a matrix
+ * @param copy - whether or not to keep the reference or copy
+ * @return Array2DRowRealMatrix
+ */
+ protected static Array2DRowRealMatrix centroidsToMatrix(final Collection<double[]> centroids, boolean copy) {
+ double[][] c = new double[centroids.size()][];
+
+ int i = 0;
+ for (double[] row : centroids)
+ c[i++] = row;
+
+ return new Array2DRowRealMatrix(c, copy);
+ }
+
+ /**
+ * Predict on an already-fit estimator
+ *
+ * @param model the already-fit estimator
+ * @param newData the data on which to predict
+ * @throws ModelNotFitException if the model isn't fit
+ */
+ protected static <E extends AbstractClusterer & CentroidLearner & BaseClassifier>
+ int[] predict(E model, RealMatrix newData) throws ModelNotFitException {
+
+ /*
+ * First get the ground truth from the estimator...
+ */
+ final int[] labels = model.getLabels(); // throws exception
+
+ /*
+ * Now fit the NearestCentroids model, and predict
+ */
+ return new NearestCentroidParameters()
+ .setMetric(model.dist_metric) // if it fails, falls back to default Euclidean...
+ .setVerbose(false) // just to be sure in case default ever changes...
+ .fitNewModel(model.getData(), labels)
+ .predict(newData);
+ }
+ }
+
+
+ /**
+ * Returns the centroid records
+ *
+ * @return an ArrayList of the centroid records
+ */
+ public ArrayList<double[]> getCentroids();
+}
diff --git a/src/main/java/com/clust4j/algo/Convergeable.java b/src/main/java/com/clust4j/algo/Convergeable.java
new file mode 100644
index 00000000..f0fded8b
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/Convergeable.java
@@ -0,0 +1,40 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+/**
+ * An interface to be implemented by {@link AbstractAutonomousClusterer}s that converge
+ *
+ * @author Taylor G Smith <tgsmith61591@gmail.com>
+ */
+public interface Convergeable extends ConvergeablePlanner {
+ public static final double DEF_TOL = 0.0;
+
+ /**
+ * Returns whether the algorithm has converged yet.
+ * If the algorithm has yet to be fit, it will return false.
+ *
+ * @return the state of algorithmic convergence
+ */
+ public boolean didConverge();
+
+ /**
+ * Get the count of iterations performed by the fit() method
+ *
+ * @return how many iterations were performed
+ */
+ public int itersElapsed();
+}
diff --git a/src/main/java/com/clust4j/algo/ConvergeablePlanner.java b/src/main/java/com/clust4j/algo/ConvergeablePlanner.java
new file mode 100644
index 00000000..aff28faf
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/ConvergeablePlanner.java
@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+interface ConvergeablePlanner extends java.io.Serializable {
+ /**
+ * The maximum number of iterations the algorithm
+ * is permitted before aborting without converging
+ *
+ * @return max iterations before convergence
+ */
+ public int getMaxIter();
+
+ /**
+ * The minimum change between iterations that will
+ * mark the fit as having converged
+ *
+ * @return the min change for convergence
+ */
+ public double getConvergenceTolerance();
+}
diff --git a/src/main/java/com/clust4j/algo/DBSCAN.java b/src/main/java/com/clust4j/algo/DBSCAN.java
new file mode 100644
index 00000000..dd7ac24d
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/DBSCAN.java
@@ -0,0 +1,378 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.log.Log.Tag.Algo;
+import com.clust4j.log.LogTimer;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.pairwise.SimilarityMetric;
+import com.clust4j.utils.MatUtils;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Stack;
+
+
+/**
+ * DBSCAN (Density Based Spatial Clustering
+ * for Applications with Noise) is a data clustering algorithm proposed by Martin Ester,
+ * Hans-Peter Kriegel, Jorg Sander and Xiaowei Xu in 1996. It is a density-based clustering
+ * algorithm: given a set of points in some space, it groups together points that are
+ * closely packed together (points with many nearby neighbors), marking as outliers
+ * points that lie alone in low-density regions (whose nearest neighbors are too far away).
+ *
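+ * <p>
+ * A minimal usage sketch (assumes {@code mat} is an already-populated
+ * {@link org.apache.commons.math3.linear.RealMatrix}):
+ * <pre>{@code
+ * DBSCAN model = new DBSCANParameters(0.75) // eps
+ *     .setMinPts(5)
+ *     .fitNewModel(mat);
+ * int[] labels = model.getLabels(); // noise points keep the -1 NOISE_CLASS label
+ * }</pre>
+ *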
+ * @author Taylor G Smith <tgsmith61591@gmail.com>, adapted from sklearn implementation by Lars Buitinck
+ * @see DBSCAN, A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise
+ * @see {@link AbstractDensityClusterer}
+ */
+final public class DBSCAN extends AbstractDBSCAN {
+ /**
+ *
+ */
+ private static final long serialVersionUID = 6749407933012974992L;
+ final private int m;
+ final public static HashSet<Class<? extends GeometricallySeparable>> UNSUPPORTED_METRICS;
+
+
+ /**
+ * Static initializer
+ */
+ static {
+ UNSUPPORTED_METRICS = new HashSet<>();
+ // Add metrics here if necessary...
+ }
+
+ @Override
+ final public boolean isValidMetric(GeometricallySeparable geo) {
+ return !UNSUPPORTED_METRICS.contains(geo.getClass()) && !(geo instanceof SimilarityMetric);
+ }
+
+ // Race conditions exist in retrieving either one of these...
+ private volatile int[] labels = null;
+ private volatile double[] sampleWeights = null;
+ private volatile boolean[] coreSamples = null;
+ private volatile int numClusters;
+ private volatile int numNoisey;
+
+
+ /**
+ * Constructs an instance of DBSCAN from the default epsilon
+ *
+ * @param data
+ */
+ protected DBSCAN(final RealMatrix data) {
+ this(data, DEF_EPS);
+ }
+
+
+ /**
+ * Constructs an instance of DBSCAN from the default planner values
+ *
+ * @param eps
+ * @param data
+ */
+ protected DBSCAN(final RealMatrix data, final double eps) {
+ this(data, new DBSCANParameters(eps));
+ }
+
+ /**
+ * Constructs an instance of DBSCAN from the provided builder
+ *
+ * @param data
+ * @param planner
+ */
+ protected DBSCAN(final RealMatrix data, final DBSCANParameters planner) {
+ super(data, planner);
+ this.m = data.getRowDimension();
+ this.eps = planner.getEps();
+
+ // Error handle...
+ if (this.eps <= 0.0)
+ error(new IllegalArgumentException("eps "
+ + "must be greater than 0.0"));
+
+ if (!isValidMetric(this.dist_metric)) {
+ warn(this.dist_metric.getName() + " is not valid for " + getName() + ". "
+ + "Falling back to default Euclidean dist");
+ setSeparabilityMetric(DEF_DIST);
+ }
+
+ logModelSummary();
+ }
+
+ @Override
+ final protected ModelSummary modelSummary() {
+ return new ModelSummary(new Object[]{
+ "Num Rows", "Num Cols", "Metric", "Epsilon", "Min Pts.", "Allow Par."
+ }, new Object[]{
+ m, data.getColumnDimension(), getSeparabilityMetric(),
+ eps, minPts,
+ parallel
+ });
+ }
+
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (o instanceof DBSCAN) {
+ DBSCAN d = (DBSCAN) o;
+
+ /*
+ * This is a litmus test of
+ * whether the model has been fit yet.
+ */
+ if (null == this.labels ^ null == d.labels)
+ return false;
+
+ return super.equals(o) // tests for UUID
+ && MatUtils.equalsExactly(this.data.getDataRef(), d.data.getDataRef())
+ && this.eps == d.eps;
+ }
+
+ return false;
+ }
+
+ public double getEps() {
+ return eps;
+ }
+
+ @Override
+ public int[] getLabels() {
+ return super.handleLabelCopy(labels);
+ }
+
+ @Override
+ public String getName() {
+ return "DBSCAN";
+ }
+
+ @Override
+ protected DBSCAN fit() {
+ synchronized (fitLock) {
+
+ if (null != labels) // Then we've already fit this...
+ return this;
+
+
+ // First get the dist matrix
+ final LogTimer timer = new LogTimer();
+
+ // Do the neighborhood assignments, get sample weights, find core samples..
+ final LogTimer neighbTimer = new LogTimer();
+ labels = new int[m]; // Initialize labels...
+ sampleWeights = new double[m]; // Init sample weights...
+ coreSamples = new boolean[m];
+
+
+ // Fit the nearest neighbor model...
+ final LogTimer rnTimer = new LogTimer();
+ final RadiusNeighbors rnModel = new RadiusNeighbors(data,
+ new RadiusNeighborsParameters(eps)
+ .setSeed(getSeed())
+ .setMetric(getSeparabilityMetric())
+ .setVerbose(false))
+ .fit();
+
+ info("fit RadiusNeighbors model in " + rnTimer.toString());
+ int[][] nearest = rnModel.getNeighbors().getIndices();
+
+
+ int[] ptNeighbs;
+ ArrayList<int[]> neighborhoods = new ArrayList<>();
+ int numCorePts = 0;
+ for (int i = 0; i < m; i++) {
+ // Each label inits to -1 as noise
+ labels[i] = NOISE_CLASS;
+ ptNeighbs = nearest[i];
+
+ // Add neighborhood...
+ int pts;
+ neighborhoods.add(ptNeighbs);
+ sampleWeights[i] = pts = ptNeighbs.length;
+ coreSamples[i] = pts >= minPts;
+
+ if (coreSamples[i])
+ numCorePts++;
+ }
+
+
+ // Log checkpoint
+ info("completed density neighborhood calculations in " + neighbTimer.toString());
+ info(numCorePts + " core point" + (numCorePts != 1 ? "s" : "") + " found");
+
+
+ // Label the points...
+ int nextLabel = 0, v;
+ final Stack<Integer> stack = new Stack<>();
+ int[] neighb;
+
+
+ LogTimer stackTimer = new LogTimer();
+ for (int i = 0; i < m; i++) {
+ stackTimer = new LogTimer();
+
+ // Skip unless the point is still unlabeled AND is a core point...
+ if (labels[i] != NOISE_CLASS || !coreSamples[i])
+ continue;
+
+ // Depth-first search starting from i, ending at the non-core points.
+ // This is very similar to the classic algorithm for computing connected
+ // components, the difference being that we label non-core points as
+ // part of a cluster (component), but don't expand their neighborhoods.
+ int labelCt = 0;
+ int cur = i; // DFS cursor: use a separate variable so the outer loop index is not clobbered
+ while (true) {
+ if (labels[cur] == NOISE_CLASS) {
+ labels[cur] = nextLabel;
+ labelCt++;
+
+ if (coreSamples[cur]) {
+ neighb = neighborhoods.get(cur);
+
+ for (int j = 0; j < neighb.length; j++) {
+ v = neighb[j];
+ if (labels[v] == NOISE_CLASS)
+ stack.push(v);
+ }
+ }
+ }
+
+
+ if (stack.size() == 0) {
+ fitSummary.add(new Object[]{
+ nextLabel, labelCt, stackTimer.formatTime(), stackTimer.wallTime()
+ });
+
+ break;
+ }
+
+ cur = stack.pop();
+ }
+
+ nextLabel++;
+ }
+
+
+ // Count missing
+ numNoisey = 0;
+ for (int lab : labels) if (lab == NOISE_CLASS) numNoisey++;
+
+
+ // corner case: numNoisey == m (never gets a fit summary)
+ if (numNoisey == m)
+ fitSummary.add(new Object[]{
+ Double.NaN, 0, stackTimer.formatTime(), stackTimer.wallTime()
+ });
+
+
+ info((numClusters = nextLabel) + " cluster" + (nextLabel != 1 ? "s" : "") +
+ " identified, " + numNoisey + " record" + (numNoisey != 1 ? "s" : "") +
+ " classified noise");
+
+ // Encode to put in order
+ labels = new NoiseyLabelEncoder(labels).fit().getEncodedLabels();
+
+ sayBye(timer);
+ return this;
+ }
+
+ }// End train
+
+ @Override
+ public Algo getLoggerTag() {
+ return com.clust4j.log.Log.Tag.Algo.DBSCAN;
+ }
+
+ @Override
+ final protected Object[] getModelFitSummaryHeaders() {
+ return new Object[]{
+ "Cluster #", "Num. Core Pts.", "Iter. Time", "Wall"
+ };
+ }
+
+ @Override
+ public int getNumberOfIdentifiedClusters() {
+ return numClusters;
+ }
+
+ @Override
+ public int getNumberOfNoisePoints() {
+ return numNoisey;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int[] predict(RealMatrix newData) {
+ final int[] fit_labels = getLabels(); // propagates errors
+ final int n = newData.getColumnDimension();
+
+ // Make sure matches dimensionally
+ if (n != this.data.getColumnDimension())
+ throw new DimensionMismatchException(n, data.getColumnDimension());
+
+ // Fit a radius model
+ RadiusNeighbors radiusModel =
+ new RadiusNeighborsParameters(eps) // no scale necessary; may already have been done
+ .setMetric(dist_metric)
+ .setSeed(getSeed())
+ .fitNewModel(data);
+
+ final int[] newLabels = new int[newData.getRowDimension()];
+ Neighborhood theHood = radiusModel.getNeighbors(newData);
+
+ int[][] indices = theHood.getIndices();
+
+ int[] idx_row;
+ for (int i = 0; i < indices.length; i++) {
+ idx_row = indices[i];
+
+ int current_class = NOISE_CLASS;
+ if (idx_row.length == 0) {
+ /*
+ * If there are no indices in this point's radius,
+ * we can just avoid the next step and exit early
+ */
+ } else { // otherwise, we know there is something in the radius--noise or other
+ int j = 0;
+ while (j < idx_row.length) {
+ current_class = fit_labels[idx_row[j]];
+
+ /*
+ * The indices are ordered by ascending distance.
+ * Even if the closest point is a noise point, it
+ * could be within a border point's radius, so we
+ * need to keep going.
+ */
+ if (NOISE_CLASS == current_class) {
+ j++;
+ } else {
+ break;
+ }
+ }
+ }
+
+ newLabels[i] = current_class;
+ }
+
+ return newLabels;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/DBSCANParameters.java b/src/main/java/com/clust4j/algo/DBSCANParameters.java
new file mode 100644
index 00000000..0aa59819
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/DBSCANParameters.java
@@ -0,0 +1,97 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.algo.AbstractDBSCAN.AbstractDBSCANParameters;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.Random;
+
+/**
+ * A builder class to provide an easier constructing
+ * interface to set custom parameters for DBSCAN
+ *
+ * @author Taylor G Smith
+ */
+final public class DBSCANParameters extends AbstractDBSCANParameters<DBSCAN> {
+ private static final long serialVersionUID = -5285244186285768512L;
+
+ private double eps = DBSCAN.DEF_EPS;
+
+
+ public DBSCANParameters() {
+ }
+
+ public DBSCANParameters(final double eps) {
+ this.eps = eps;
+ }
+
+
+ @Override
+ public DBSCAN fitNewModel(RealMatrix data) {
+ return new DBSCAN(data, this.copy()).fit();
+ }
+
+ @Override
+ public DBSCANParameters copy() {
+ return new DBSCANParameters(eps)
+ .setMinPts(minPts)
+ .setMetric(metric)
+ .setSeed(seed)
+ .setVerbose(verbose)
+ .setForceParallel(parallel);
+ }
+
+ public double getEps() {
+ return eps;
+ }
+
+ public DBSCANParameters setEps(final double eps) {
+ this.eps = eps;
+ return this;
+ }
+
+ @Override
+ public DBSCANParameters setMinPts(final int minPts) {
+ this.minPts = minPts;
+ return this;
+ }
+
+ @Override
+ public DBSCANParameters setSeed(final Random seed) {
+ this.seed = seed;
+ return this;
+ }
+
+ @Override
+ public DBSCANParameters setMetric(final GeometricallySeparable dist) {
+ this.metric = dist;
+ return this;
+ }
+
+ public DBSCANParameters setVerbose(final boolean v) {
+ this.verbose = v;
+ return this;
+ }
+
+ @Override
+ public DBSCANParameters setForceParallel(boolean b) {
+ this.parallel = b;
+ return this;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/HDBSCAN.java b/src/main/java/com/clust4j/algo/HDBSCAN.java
new file mode 100644
index 00000000..4a71a756
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/HDBSCAN.java
@@ -0,0 +1,1702 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.GlobalState;
+import com.clust4j.log.Log.Tag.Algo;
+import com.clust4j.log.LogTimer;
+import com.clust4j.log.Loggable;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.DistanceMetric;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.pairwise.Pairwise;
+import com.clust4j.utils.EntryPair;
+import com.clust4j.utils.MatUtils;
+import com.clust4j.utils.MatUtils.MatSeries;
+import com.clust4j.utils.QuadTup;
+import com.clust4j.utils.Series.Inequality;
+import com.clust4j.utils.VecUtils;
+import com.clust4j.utils.VecUtils.DoubleSeries;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.Precision;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+/**
+ * Hierarchical Density-Based Spatial Clustering of Applications with Noise.
+ * Performs {@link DBSCAN} over varying epsilon values and integrates the result to
+ * find a clustering that gives the best stability over epsilon. This allows
+ * HDBSCAN to find clusters of varying densities (unlike DBSCAN), and be more
+ * robust to parameter selection.
+ *
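+ * <p>
+ * A minimal usage sketch (assuming {@code HDBSCANParameters} exposes the same
+ * {@code fitNewModel(RealMatrix)} / {@code getLabels()} chain as the other
+ * parameter classes in this package):
+ * <pre>{@code
+ * HDBSCAN model = new HDBSCANParameters(5) // minPts
+ *     .fitNewModel(mat);
+ * int[] labels = model.getLabels();
+ * }</pre>
+ *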
+ * @author Taylor G Smith, adapted from the Python
+ * HDBSCAN package, inspired by
+ * the paper by
+ * R. Campello, D. Moulavi, and J. Sander
+ */
+final public class HDBSCAN extends AbstractDBSCAN {
+ private static final long serialVersionUID = -5112901322434131541L;
+ public static final HDBSCAN_Algorithm DEF_ALGO = HDBSCAN_Algorithm.AUTO;
+ public static final double DEF_ALPHA = 1.0;
+ public static final boolean DEF_APPROX_MIN_SPAN = true;
+ public static final int DEF_LEAF_SIZE = 40;
+ public static final int DEF_MIN_CLUST_SIZE = 5;
+ /**
+ * The number of features that should trigger a boruvka implementation
+ */
+ static final int boruvka_n_features_ = 60;
+ static final Set<Class<? extends GeometricallySeparable>> fast_metrics_;
+
+ /**
+ * Not final because can change if auto-enabled
+ */
+ protected HDBSCAN_Algorithm algo;
+ private final double alpha;
+ private final boolean approxMinSpanTree;
+ private final int min_cluster_size;
+ private final int leafSize;
+
+ private volatile HDBSCANLinkageTree tree = null;
+ private volatile double[][] dist_mat = null;
+ private volatile int[] labels = null;
+ private volatile int numClusters = -1;
+ private volatile int numNoisey = -1;
+ /**
+ * A copy of the data array inside the data matrix
+ */
+ private volatile double[][] dataData = null;
+
+
+ private interface HInitializer extends MetricValidator {
+ public HDBSCANLinkageTree initTree(HDBSCAN h);
+ }
+
+ public static enum HDBSCAN_Algorithm implements HInitializer {
+ /**
+ * Automatically selects the appropriate algorithm
+ * given dimensions of the dataset.
+ */
+ AUTO {
+ @Override
+ public HDBSCANLinkageTree initTree(HDBSCAN h) {
+ final Class<? extends GeometricallySeparable> clz = h.dist_metric.getClass();
+ final int n = h.data.getColumnDimension();
+
+ // rare situation... only if oddball dist
+ if (!fast_metrics_.contains(clz)) {
+ return GENERIC.initTree(h);
+ } else if (KDTree.VALID_METRICS.contains(clz)) {
+ return n > boruvka_n_features_ ?
+ BORUVKA_KDTREE.initTree(h) :
+ PRIMS_KDTREE.initTree(h);
+ }
+
+ // otherwise is valid balltree metric
+ return n > boruvka_n_features_ ?
+ BORUVKA_BALLTREE.initTree(h) :
+ PRIMS_BALLTREE.initTree(h);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ throw new UnsupportedOperationException("auto does not have supported metrics");
+ }
+ },
+
+ /**
+ * Generates a minimum spanning tree using a pairwise,
+ * full distance matrix. Generally performs slower than
+ * the other algorithms for larger datasets, but has less
+ * setup overhead.
+ *
+ * @see Pairwise
+ */
+ GENERIC {
+ @Override
+ public GenericTree initTree(HDBSCAN h) {
+ // we set this in case it was called by auto
+ h.algo = this;
+ ensureMetric(h, this);
+ return h.new GenericTree();
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ HashSet<Class<? extends DistanceMetric>> unsupported = new HashSet<>();
+
+ for (DistanceMetric d : Distance.binaryDistances())
+ unsupported.add(d.getClass());
+
+ // if we ever have MORE invalid ones, add them here...
+ return !unsupported.contains(g.getClass());
+ }
+ },
+
+ /**
+ * Prim's algorithm is a greedy algorithm that finds a
+ * minimum spanning tree for a weighted undirected graph.
+ * This means it finds a subset of the edges that forms a
+ * tree that includes every vertex, where the total weight
+ * of all the edges in the tree is minimized. This implementation
+ * internally uses a {@link KDTree} to handle the graph
+ * linkage function.
+ *
+ * @see KDTree
+ */
+ PRIMS_KDTREE {
+ @Override
+ public PrimsKDTree initTree(HDBSCAN h) {
+ // we set this in case it was called by auto
+ h.algo = this;
+ ensureMetric(h, this);
+ return h.new PrimsKDTree(h.leafSize);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ return KDTree.VALID_METRICS.contains(g.getClass());
+ }
+ },
+
+ /**
+ * Prim's algorithm is a greedy algorithm that finds a
+ * minimum spanning tree for a weighted undirected graph.
+ * This means it finds a subset of the edges that forms a
+ * tree that includes every vertex, where the total weight
+ * of all the edges in the tree is minimized. This implementation
+ * internally uses a {@link BallTree} to handle the graph
+ * linkage function.
+ *
+ * @see BallTree
+ */
+ PRIMS_BALLTREE {
+ @Override
+ public PrimsBallTree initTree(HDBSCAN h) {
+ // we set this in case it was called by auto
+ h.algo = this;
+ ensureMetric(h, this);
+ return h.new PrimsBallTree(h.leafSize);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ return BallTree.VALID_METRICS.contains(g.getClass());
+ }
+ },
+
+ /**
+ * Uses Boruvka's algorithm to find a minimum spanning
+ * tree. Internally uses a {@link KDTree} to handle the
+ * linkage function.
+ *
+ * @see BoruvkaAlgorithm
+ * @see KDTree
+ */
+ BORUVKA_KDTREE {
+ @Override
+ public BoruvkaKDTree initTree(HDBSCAN h) {
+ // we set this in case it was called by auto
+ h.algo = this;
+ ensureMetric(h, this);
+ return h.new BoruvkaKDTree(h.leafSize);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ return KDTree.VALID_METRICS.contains(g.getClass());
+ }
+ },
+
+ /**
+ * Uses Boruvka's algorithm to find a minimum spanning
+ * tree. Internally uses a {@link BallTree} to handle the
+ * linkage function.
+ *
+ * @see BoruvkaAlgorithm
+ * @see BallTree
+ */
+ BORUVKA_BALLTREE {
+ @Override
+ public BoruvkaBallTree initTree(HDBSCAN h) {
+ // we set this in case it was called by auto
+ h.algo = this;
+ ensureMetric(h, this);
+ return h.new BoruvkaBallTree(h.leafSize);
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable g) {
+ return BallTree.VALID_METRICS.contains(g.getClass())
+ // For some reason Boruvka hates Canberra...
+ && !g.equals(Distance.CANBERRA)
+ ;
+ }
+ };
+
+ private static void ensureMetric(HDBSCAN h, HDBSCAN_Algorithm a) {
+ if (!a.isValidMetric(h.dist_metric)) {
+ h.warn(h.dist_metric.getName() + " is not valid for " + a +
+ ". Falling back to default Euclidean.");
+ h.setSeparabilityMetric(DEF_DIST);
+ }
+ }
+ }
+
+
+ static {
+ fast_metrics_ = new HashSet<Class<? extends GeometricallySeparable>>();
+ fast_metrics_.addAll(KDTree.VALID_METRICS);
+ fast_metrics_.addAll(BallTree.VALID_METRICS);
+ }
+
+
+ /**
+ * Is the provided metric valid for this model?
+ */
+ @Override
+ final public boolean isValidMetric(GeometricallySeparable geo) {
+ return this.algo.isValidMetric(geo);
+ }
+
+
+ /**
+ * Constructs an instance of HDBSCAN from the default values
+ *
+ * @param data
+ */
+ protected HDBSCAN(final RealMatrix data) {
+ this(data, DEF_MIN_PTS);
+ }
+
+ /**
+ * Constructs an instance of HDBSCAN from the default values
+ *
+ * @param data
+ * @param minPts
+ */
+ protected HDBSCAN(final RealMatrix data, final int minPts) {
+ this(data, new HDBSCANParameters(minPts));
+ }
+
+ /**
+ * Constructs an instance of HDBSCAN from the provided builder
+ *
+ * @param data
+ * @param planner
+ * @throws IllegalArgumentException if alpha is 0
+ */
+ protected HDBSCAN(final RealMatrix data, final HDBSCANParameters planner) {
+ super(data, planner);
+
+ this.algo = planner.getAlgo();
+ this.alpha = planner.getAlpha();
+ this.approxMinSpanTree = planner.getApprox();
+ this.min_cluster_size = planner.getMinClusterSize();
+ this.leafSize = planner.getLeafSize();
+
+ if (alpha <= 0.0) throw new IllegalArgumentException("alpha must be greater than 0");
+ if (leafSize < 1) throw new IllegalArgumentException("leafsize must be greater than 0");
+
+ logModelSummary();
+ }
+
+ @Override
+ final protected ModelSummary modelSummary() {
+ return new ModelSummary(new Object[]{
+ "Num Rows", "Num Cols", "Metric", "Algo.", "Allow Par.", "Min Pts.", "Min Clust. Size", "Alpha"
+ }, new Object[]{
+ data.getRowDimension(), data.getColumnDimension(),
+ getSeparabilityMetric(), algo,
+ parallel,
+ minPts, min_cluster_size, alpha
+ });
+ }
+
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (o instanceof HDBSCAN) {
+ HDBSCAN h = (HDBSCAN) o;
+
+ /*
+ * Has one been fit and not the other?
+ */
+ if (null == this.labels ^ null == h.labels)
+ return false;
+
+ return super.equals(o) // UUID test
+ && MatUtils.equalsExactly(this.data.getDataRef(), h.data.getDataRef())
+ && (null == this.labels ? true : VecUtils.equalsExactly(this.labels, h.labels))
+ && this.algo.equals(h.algo)
+ && this.alpha == h.alpha
+ && this.leafSize == h.leafSize
+ && this.min_cluster_size == h.min_cluster_size;
+ }
+
+ return false;
+ }
+
+
+ /**
+ * This class extension is for the sake of testing; it restricts
+ * types to a subclass of Number and adds the method
+ * {@link CompQuadTup#almostEquals(CompQuadTup)} to measure whether
+ * values are equal within a margin of 1e-8.
+ *
+ * @param <ONE>
+ * @param <TWO>
+ * @param <THREE>
+ * @param <FOUR>
+ * @author Taylor G Smith
+ */
+ protected final static class CompQuadTup<ONE extends Number, TWO extends Number, THREE extends Number, FOUR extends Number>
+ extends QuadTup<ONE, TWO, THREE, FOUR> {
+ private static final long serialVersionUID = -8699738868282635229L;
+
+ public CompQuadTup(ONE one, TWO two, THREE three, FOUR four) {
+ super(one, two, three, four);
+ }
+
+ /*
+ * For testing
+ */
+ public boolean almostEquals(CompQuadTup other) {
+ return Precision.equals(this.one.doubleValue(), other.one.doubleValue(), 1e-8)
+ && Precision.equals(this.two.doubleValue(), other.two.doubleValue(), 1e-8)
+ && Precision.equals(this.three.doubleValue(), other.three.doubleValue(), 1e-8)
+ && Precision.equals(this.four.doubleValue(), other.four.doubleValue(), 1e-8);
+ }
+ }
+
+ /**
+ * A simple extension of {@link HashSet} with convenience
+ * constructors taking an initial size or an existing collection
+ *
+ * @param <T>
+ * @author Taylor G Smith
+ */
+ protected final static class HSet<T> extends HashSet<T> {
+ private static final long serialVersionUID = 5185550036712184095L;
+
+ HSet(int size) {
+ super(size);
+ }
+
+ HSet(Collection<? extends T> coll) {
+ super(coll);
+ }
+ }
+
+ /**
+ * Constructs an {@link HSet} from the labels
+ *
+ * @author Taylor G Smith
+ */
+ protected final static class LabelHSetFactory {
+ static HSet<Integer> build(int[] labs) {
+ HSet<Integer> res = new HSet<>(labs.length);
+ for (int i : labs)
+ res.add(i);
+
+ return res;
+ }
+ }
+
+
+ /**
+ * Classes that will explicitly need to define
+ * reachability will have to implement this interface
+ */
+ interface ExplicitMutualReachability {
+ double[][] mutualReachability();
+ }
+
+ /**
+ * Mutual reachability is implicit when using
+ * {@link BoruvkaAlgorithm},
+ * thus we don't need these classes to implement
+ * {@link ExplicitMutualReachability#mutualReachability()}
+ */
+ interface Boruvka {
+ }
+
+ /**
+ * Mutual reachability is implicit when using
+ * {@link LinkageTreeUtils#mstLinkageCore_cdist},
+ * thus we don't need these classes to implement
+ * {@link ExplicitMutualReachability#mutualReachability()}
+ */
+ interface Prim {
+ }
+
+
+ /**
+ * Utility MST linkage methods
+ *
+ * @author Taylor G Smith
+ */
+ protected static abstract class LinkageTreeUtils {
+
+ /**
+ * Perform a breadth first search on a tree
+ *
+ * @param hierarchy the single-linkage hierarchy (one row per merge)
+ * @param root the node index at which to begin the search
+ * @return the node indices reachable from root, in breadth-first order
+ */
+ // Tested: passing
+ static ArrayList<Integer> breadthFirstSearch(final double[][] hierarchy, final int root) {
+ ArrayList<Integer> toProcess = new ArrayList<>(), tmp;
+ int dim = hierarchy.length, maxNode = 2 * dim, numPoints = maxNode - dim + 1;
+
+ toProcess.add(root);
+ ArrayList<Integer> result = new ArrayList<>();
+ while (!toProcess.isEmpty()) {
+ result.addAll(toProcess);
+
+ tmp = new ArrayList<>();
+ for (Integer x : toProcess)
+ if (x >= numPoints)
+ tmp.add(x - numPoints);
+ toProcess = tmp;
+
+ tmp = new ArrayList<>();
+ if (!toProcess.isEmpty()) {
+ for (Integer row : toProcess)
+ for (int i = 0; i < 2; i++)
+ tmp.add((int) hierarchy[row][wraparoundIdxGet(hierarchy[row].length, i)]);
+
+ toProcess = tmp;
+ }
+ }
+
+ return result;
+ }
+
+ // Tested: passing
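+ /**
+ * Compute per-cluster stability: for each parent cluster, the
+ * sum of (lambda - lambda_birth(parent)) * childSize over its rows.
+ */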
+ static TreeMap<Integer, Double> computeStability(ArrayList<CompQuadTup<Integer, Integer, Double, Integer>> condensed) {
+ double[] resultArr, births, lambdas = new double[condensed.size()];
+ int[] sizes = new int[condensed.size()], parents = new int[condensed.size()];
+ int child, parent, childSize, resultIdx, currentChild = -1, idx = 0, row = 0;
+ double lambda, minLambda = 0;
+
+
+ /* Populate parents, sizes and lambdas pre-sort and get min/max parent info
+ * ['parent', 'child', 'lambda', 'childSize']
+ */
+ int largestChild = Integer.MIN_VALUE,
+ minParent = Integer.MAX_VALUE,
+ maxParent = Integer.MIN_VALUE;
+ for (CompQuadTup<Integer, Integer, Double, Integer> q : condensed) {
+ parent = q.getFirst();
+ child = q.getSecond();
+ lambda = q.getThird();
+ childSize = q.getFourth();
+
+ if (child > largestChild)
+ largestChild = child;
+ if (parent < minParent)
+ minParent = parent;
+ if (parent > maxParent)
+ maxParent = parent;
+
+ parents[idx] = parent;
+ sizes[idx] = childSize;
+ lambdas[idx] = lambda;
+ idx++;
+ }
+
+ int numClusters = maxParent - minParent + 1;
+ births = VecUtils.rep(Double.NaN, largestChild + 1);
+
+ /*
+ * Perform sort, then get sorted lambdas and children
+ */
+ Collections.sort(condensed, new Comparator<QuadTup<Integer, Integer, Double, Integer>>() {
+ @Override
+ public int compare(QuadTup<Integer, Integer, Double, Integer> q1,
+ QuadTup<Integer, Integer, Double, Integer> q2) {
+ int cmp = q1.getSecond().compareTo(q2.getSecond());
+
+ if (cmp == 0) {
+ cmp = q1.getThird().compareTo(q2.getThird());
+ return cmp;
+ }
+
+ return cmp;
+ }
+ });
+
+
+ /*
+ * Go through sorted list...
+ */
+ for (row = 0; row < condensed.size(); row++) {
+ CompQuadTup<Integer, Integer, Double, Integer> q = condensed.get(row);
+ child = q.getSecond();
+ lambda = q.getThird();
+
+ if (child == currentChild)
+ minLambda = FastMath.min(minLambda, lambda);
+ else if (currentChild != -1) {
+ // Already been initialized
+ births[currentChild] = minLambda;
+ currentChild = child;
+ minLambda = lambda;
+ } else {
+ // Initialize
+ currentChild = child;
+ minLambda = lambda;
+ }
+ }
+
+ resultArr = new double[numClusters];
+
+
+ // Second loop
+ double birthParent;
+ for (idx = 0; idx < condensed.size(); idx++) {
+ parent = parents[idx];
+ lambda = lambdas[idx];
+ childSize = sizes[idx];
+ resultIdx = parent - minParent;
+
+ // the Cython exploits the C contiguous pointer array's
+ // out of bounds allowance (2.12325E-314), but we have to
+ // do a check for that...
+ birthParent = parent >= births.length ? GlobalState.Mathematics.TINY : births[parent];
+ resultArr[resultIdx] += (lambda - birthParent) * childSize;
+ }
+
+
+ double[] top = VecUtils.asDouble(VecUtils.arange(minParent, maxParent + 1));
+ double[][] mat = MatUtils.transpose(VecUtils.vstack(top, resultArr));
+
+ TreeMap<Integer, Double> result = new TreeMap<>();
+ for (idx = 0; idx < mat.length; idx++)
+ result.put((int) mat[idx][0], mat[idx][1]);
+
+ return result;
+ }
+
+ // Tested: passing
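+ /**
+ * Condense a single-linkage tree: prune subtrees with fewer than
+ * minSize points, emitting (parent, child, lambda, childSize)
+ * tuples, where lambda = 1 / merge distance.
+ */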
+ static ArrayList<CompQuadTup<Integer, Integer, Double, Integer>> condenseTree(final double[][] hierarchy, final int minSize) {
+ final int m = hierarchy.length;
+ int root = 2 * m,
+ numPoints = root / 2 + 1 /*Integer division*/,
+ nextLabel = numPoints + 1;
+
+ // Get node list from BFS
+ ArrayList<Integer> nodeList = breadthFirstSearch(hierarchy, root), tmpList;
+ ArrayList<CompQuadTup<Integer, Integer, Double, Integer>> resultList = new ArrayList<>();
+
+ // Indices needing relabeling -- cython code assigns this to nodeList.size()
+ // but oftentimes this is far too small and causes out-of-bounds exceptions...
+ // Changed to root + 1 on 02/01/2016; this should be the max node ever in the resultList
+ int[] relabel = new int[root + 1]; //nodeList.size()
+ boolean[] ignore = new boolean[root + 1];
+ double[] children;
+
+ double lambda;
+ int left, right, leftCount, rightCount;
+
+ // Sanity check
+ // System.out.println("Root: " + root + ", Relabel length: " + relabel.length + ", m: " + m + ", Relabel array: " + Arrays.toString(relabel));
+
+ // The cython code doesn't check for bounds and sloppily
+ // assigns this even if root > relabel.length.
+ relabel[root] = numPoints;
+
+
+ for (Integer node : nodeList) {
+
+ if (ignore[node] || node < numPoints)
+ continue;
+
+ children = hierarchy[wraparoundIdxGet(hierarchy.length, node - numPoints)];
+ left = (int) children[0];
+ right = (int) children[1];
+
+ if (children[2] > 0)
+ lambda = 1.0 / children[2];
+ else lambda = Double.POSITIVE_INFINITY;
+
+ if (left >= numPoints)
+ leftCount = (int) (hierarchy[wraparoundIdxGet(hierarchy.length, left - numPoints)][3]);
+ else leftCount = 1;
+
+ if (right >= numPoints)
+ rightCount = (int) (hierarchy[wraparoundIdxGet(hierarchy.length, right - numPoints)][3]);
+ else rightCount = 1;
+
+
+ if (leftCount >= minSize && rightCount >= minSize) {
+ relabel[left] = nextLabel++;
+ resultList.add(new CompQuadTup<Integer, Integer, Double, Integer>(
+ relabel[wraparoundIdxGet(relabel.length, node)],
+ relabel[wraparoundIdxGet(relabel.length, left)],
+ lambda, leftCount));
+
+ relabel[wraparoundIdxGet(relabel.length, right)] = nextLabel++;
+ resultList.add(new CompQuadTup<Integer, Integer, Double, Integer>(
+ relabel[wraparoundIdxGet(relabel.length, node)],
+ relabel[wraparoundIdxGet(relabel.length, right)],
+ lambda, rightCount));
+
+
+ } else if (leftCount < minSize && rightCount < minSize) {
+ tmpList = breadthFirstSearch(hierarchy, left);
+ for (Integer subnode : tmpList) {
+ if (subnode < numPoints)
+ resultList.add(new CompQuadTup<Integer, Integer, Double, Integer>(
+ relabel[wraparoundIdxGet(relabel.length, node)], subnode,
+ lambda, 1));
+ ignore[subnode] = true;
+ }
+
+ tmpList = breadthFirstSearch(hierarchy, right);
+ for (Integer subnode : tmpList) {
+ if (subnode < numPoints)
+ resultList.add(new CompQuadTup<Integer, Integer, Double, Integer>(
+ relabel[wraparoundIdxGet(relabel.length, node)], subnode,
+ lambda, 1));
+ ignore[subnode] = true;
+ }
+
+
+ } else if (leftCount < minSize) {
+ relabel[right] = relabel[node];
+ tmpList = breadthFirstSearch(hierarchy, left);
+
+ for (Integer subnode : tmpList) {
+ if (subnode < numPoints)
+ resultList.add(new CompQuadTup<Integer, Integer, Double, Integer>(
+ relabel[wraparoundIdxGet(relabel.length, node)], subnode,
+ lambda, 1));
+ ignore[subnode] = true;
+ }
+ } else {
+ relabel[left] = relabel[node];
+ tmpList = breadthFirstSearch(hierarchy, right);
+ for (Integer subnode : tmpList) {
+ if (subnode < numPoints)
+ resultList.add(new CompQuadTup<Integer, Integer, Double, Integer>(
+ relabel[wraparoundIdxGet(relabel.length, node)], subnode,
+ lambda, 1));
+ ignore[subnode] = true;
+ }
+ }
+ }
+
+ return resultList;
+ }
+
+ /**
+ * Generic linkage core method: a dense Prim's-style sweep that
+ * builds a minimum spanning tree over a full distance matrix
+ *
+ * @param X the full (mutual reachability) distance matrix
+ * @param m the number of rows (points) in X
+ * @return an (m - 1) x 3 MST edge list of [from, to, weight] rows
+ */
+ static double[][] minSpanTreeLinkageCore(final double[][] X, final int m) { // Tested: passing
+ int[] node_labels, current_labels, tmp_labels;
+ double[] current_distances, left, right;
+ boolean[] label_filter;
+ boolean val;
+ int current_node, new_node_index, new_node, i, j, trueCt, idx;
+ DoubleSeries series;
+
+ double[][] result = new double[m - 1][3];
+ node_labels = VecUtils.arange(m);
+ current_node = 0;
+ current_distances = VecUtils.rep(Double.POSITIVE_INFINITY, m);
+ current_labels = node_labels;
+
+
+ for (i = 1; i < node_labels.length; i++) {
+
+ // Create the boolean mask; takes 2N to create mask and then filter
+ // however, creating the left vector concurrently
+ // trims off one N pass. This could be done using Vector.VecSeries
+ // but that would add an extra pass of N
+ idx = 0;
+ trueCt = 0;
+ label_filter = new boolean[current_labels.length];
+ for (j = 0; j < label_filter.length; j++) {
+ val = current_labels[j] != current_node;
+ if (val)
+ trueCt++;
+
+ label_filter[j] = val;
+ }
+
+ tmp_labels = new int[trueCt];
+ left = new double[trueCt];
+ for (j = 0; j < current_labels.length; j++) {
+ if (label_filter[j]) {
+ tmp_labels[idx] = current_labels[j];
+ left[idx] = current_distances[j];
+ idx++;
+ }
+ }
+
+ current_labels = tmp_labels;
+ right = new double[current_labels.length];
+ for (j = 0; j < right.length; j++)
+ right[j] = X[current_node][current_labels[j]];
+
+ // Build the current_distances vector
+ series = new DoubleSeries(left, Inequality.LESS_THAN, right);
+ current_distances = VecUtils.where(series, left, right);
+
+
+ // Get next iter values
+ new_node_index = VecUtils.argMin(current_distances);
+ new_node = current_labels[new_node_index];
+ result[i - 1][0] = (double) current_node;
+ result[i - 1][1] = (double) new_node;
+ result[i - 1][2] = current_distances[new_node_index];
+
+ current_node = new_node;
+ }
+
+ return result;
+ }
+
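+ /**
+ * Prim's-style MST construction that computes mutual reachability
+ * on the fly from the raw data rows and precomputed core distances,
+ * rather than from a full pairwise distance matrix.
+ */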
+ static double[][] minSpanTreeLinkageCore_cdist(final double[][] raw, final double[] coreDistances, GeometricallySeparable sep, final double alpha) {
+ double[] currentDists;
+ int[] inTreeArr;
+ double[][] resultArr;
+
+ int currentNode = 0, newNode, i, j, dim = raw.length;
+ double currentNodeCoreDist, rightVal, leftVal, coreVal, newDist;
+
+ resultArr = new double[dim - 1][3];
+ inTreeArr = new int[dim];
+ currentDists = VecUtils.rep(Double.POSITIVE_INFINITY, dim);
+
+
+ for (i = 1; i < dim; i++) {
+ inTreeArr[currentNode] = 1;
+ currentNodeCoreDist = coreDistances[currentNode];
+
+ newDist = Double.MAX_VALUE;
+ newNode = 0;
+
+ for (j = 0; j < dim; j++) {
+ if (inTreeArr[j] != 0)
+ continue; // only skips currentNode idx
+
+ rightVal = currentDists[j];
+ leftVal = sep.getDistance(raw[currentNode], raw[j]);
+
+ if (alpha != 1.0)
+ leftVal /= alpha;
+
+ coreVal = coreDistances[j];
+ if (currentNodeCoreDist > rightVal || coreVal > rightVal
+ || leftVal > rightVal) {
+ if (rightVal < newDist) { // Should always be the case?
+ newDist = rightVal;
+ newNode = j;
+ }
+
+ continue;
+ }
+
+
+ if (coreVal > currentNodeCoreDist) {
+ if (coreVal > leftVal)
+ leftVal = coreVal;
+ } else if (currentNodeCoreDist > leftVal) {
+ leftVal = currentNodeCoreDist;
+ }
+
+
+ if (leftVal < rightVal) {
+ currentDists[j] = leftVal;
+ if (leftVal < newDist) {
+ newDist = leftVal;
+ newNode = j;
+ }
+ } else if (rightVal < newDist) {
+ newDist = rightVal;
+ newNode = j;
+ }
+ } // end for j
+
+ resultArr[i - 1][0] = currentNode;
+ resultArr[i - 1][1] = newNode;
+ resultArr[i - 1][2] = newDist;
+ currentNode = newNode;
+ } // end for i
+
+
+ return resultArr;
+ }
+
+
+ /**
+ * The index may be -1; this will return
+ * the index of the length of the array minus
+ * the absolute value of the index in the case
+ * of negative indices, like the original Python
+ * code.
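+ * E.g., {@code wraparoundIdxGet(5, -1) == 4}, while
+ * {@code wraparoundIdxGet(5, 2) == 2}.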
+ *
+ * @param array_len the length of the array being indexed
+ * @param idx the possibly-negative index to resolve
+ * @return the index to be queried in wrap-around indexing
+ * @throws ArrayIndexOutOfBoundsException if the absolute value of the index
+ * exceeds the length of the array
+ */
+ static int wraparoundIdxGet(int array_len, int idx) {
+ int abs;
+ if ((abs = FastMath.abs(idx)) > array_len)
+ throw new ArrayIndexOutOfBoundsException(idx);
+ if (idx >= 0)
+ return idx;
+ return array_len - abs;
+ }
+
+ static double[][] mutualReachability(double[][] dist_mat, int minPts, double alpha) {
+ final int size = dist_mat.length;
+ minPts = FastMath.min(size - 1, minPts);
+
+ final double[] core_distances = MatUtils
+ .sortColsAsc(dist_mat)[minPts];
+
+ if (alpha != 1.0)
+ dist_mat = MatUtils.scalarDivide(dist_mat, alpha);
+
+
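+ // The two where() passes below compute, elementwise,
+ // max(core_distances[i], core_distances[j], dist_mat[i][j]) --
+ // the mutual reachability distance between points i and j.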
+ final MatSeries ser1 = new MatSeries(core_distances, Inequality.GREATER_THAN, dist_mat);
+ double[][] stage1 = MatUtils.where(ser1, core_distances, dist_mat);
+
+ stage1 = MatUtils.transpose(stage1);
+ final MatSeries ser2 = new MatSeries(core_distances, Inequality.GREATER_THAN, stage1);
+ final double[][] result = MatUtils.where(ser2, core_distances, stage1);
+
+ return MatUtils.transpose(result);
+ }
+ }
+
+
+ /**
+ * The top level class for all HDBSCAN linkage trees.
+ *
+ * @author Taylor G Smith
+ */
+ abstract class HDBSCANLinkageTree {
+ final HDBSCAN model;
+ final GeometricallySeparable metric;
+ final int m, n;
+
+ HDBSCANLinkageTree() {
+ model = HDBSCAN.this;
+ metric = model.getSeparabilityMetric();
+ m = model.data.getRowDimension();
+ n = model.data.getColumnDimension();
+ }
+
+ abstract double[][] link();
+ }
+
+
+ /**
+ * Algorithms that utilize {@link NearestNeighborHeapSearch}
+ * algorithms for mutual reachability
+ *
+ * @author Taylor G Smith
+ */
+ abstract class HeapSearchAlgorithm extends HDBSCANLinkageTree {
+ final int leafSize;
+
+ HeapSearchAlgorithm(int leafSize) {
+ super();
+ this.leafSize = leafSize;
+ }
+
+ abstract NearestNeighborHeapSearch getTree(double[][] X);
+
+ abstract String getTreeName();
+
+ /**
+ * The linkage function to be used for any classes
+ * implementing the {@link Prim} interface.
+ *
+ * @param dt the raw data rows
+ * @return the labeled single-linkage tree
+ */
+ final double[][] primTreeLinkageFunction(double[][] dt) {
+ final int min_points = FastMath.min(m - 1, minPts);
+
+ LogTimer timer = new LogTimer();
+ model.info("building " + getTreeName() + " search tree...");
+ NearestNeighborHeapSearch tree = getTree(dt);
+ model.info("completed NearestNeighborHeapSearch construction in " + timer.toString());
+
+
+ // Query for dists to k nearest neighbors -- no longer use breadth first!
+ Neighborhood query = tree.query(dt, min_points, true, true);
+ double[][] dists = query.getDistances();
+ double[] coreDistances = MatUtils.getColumn(dists, dists[0].length - 1);
+
+ double[][] minSpanningTree = LinkageTreeUtils
+ .minSpanTreeLinkageCore_cdist(dt,
+ coreDistances, metric, alpha);
+
+ return label(MatUtils.sortAscByCol(minSpanningTree, 2));
+ }
+
+ /**
+ * The linkage function to be used for any classes
+ * implementing the {@link Boruvka} interface.
+ *
+ * @param dt
+ * @return
+ */
+ final double[][] boruvkaTreeLinkageFunction(double[][] dt) {
+ final int min_points = FastMath.min(m - 1, minPts);
+ int ls = FastMath.max(leafSize, 3);
+
+ model.info("building " + getTreeName() + " search tree...");
+
+ LogTimer timer = new LogTimer();
+ NearestNeighborHeapSearch tree = getTree(dt);
+ model.info("completed NearestNeighborHeapSearch construction in " + timer.toString());
+
+ // We can safely cast the metric to DistanceMetric at this point
+ final BoruvkaAlgorithm alg = new BoruvkaAlgorithm(tree, min_points,
+ (DistanceMetric) metric, ls / 3, approxMinSpanTree,
+ alpha, model);
+
+ double[][] minSpanningTree = alg.spanningTree();
+ return label(MatUtils.sortAscByCol(minSpanningTree, 2));
+ }
+ }
+
+ /**
+ * A class for HDBSCAN algorithms that utilize {@link KDTree}
+ * search spaces for segmenting nearest neighbors
+ *
+ * @author Taylor G Smith
+ */
+ abstract class KDTreeAlgorithm extends HeapSearchAlgorithm {
+ KDTreeAlgorithm(int leafSize) {
+ super(leafSize);
+ }
+
+ @Override
+ String getTreeName() {
+ return "KD";
+ }
+
+ @Override
+ final KDTree getTree(double[][] X) {
+ // We can safely cast the sep metric as DistanceMetric
+ // after the check in the constructor
+ return new KDTree(X, this.leafSize,
+ (DistanceMetric) metric, model);
+ }
+ }
+
+ /**
+ * A class for HDBSCAN algorithms that utilize {@link BallTree}
+ * search spaces for segmenting nearest neighbors
+ *
+ * @author Taylor G Smith
+ */
+ abstract class BallTreeAlgorithm extends HeapSearchAlgorithm {
+ BallTreeAlgorithm(int leafSize) {
+ super(leafSize);
+ }
+
+ @Override
+ String getTreeName() {
+ return "Ball";
+ }
+
+ @Override
+ final BallTree getTree(double[][] X) {
+ // We can safely cast the sep metric as DistanceMetric
+ // after the check in the constructor
+ return new BallTree(X, this.leafSize,
+ (DistanceMetric) metric, model);
+ }
+ }
+
+ /**
+ * Generic single linkage tree that uses an
+ * upper triangular distance matrix to compute
+ * mutual reachability
+ *
+ * @author Taylor G Smith
+ */
+ class GenericTree extends HDBSCANLinkageTree implements ExplicitMutualReachability {
+ GenericTree() {
+ super();
+
+ // The generic implementation requires the computation of an UT dist mat
+ final LogTimer s = new LogTimer();
+ dist_mat = Pairwise.getDistance(data, getSeparabilityMetric(), false, false);
+ info("completed distance matrix computation in " + s.toString());
+ }
+
+ @Override
+ double[][] link() {
+ final double[][] mutual_reachability = mutualReachability();
+ double[][] min_spanning_tree = LinkageTreeUtils
+ .minSpanTreeLinkageCore(mutual_reachability, m);
+
+ // Sort edges of the min_spanning_tree by weight
+ min_spanning_tree = MatUtils.sortAscByCol(min_spanning_tree, 2);
+ return label(min_spanning_tree);
+ }
+
+ @Override
+ public double[][] mutualReachability() {
+ /*// this shouldn't be able to happen...
+ if(null == dist_mat)
+ throw new IllegalClusterStateException("dist matrix is null; "
+ + "this only can happen when the model attempts to invoke "
+ + "mutualReachability on a tree without proper initialization "
+ + "or after the model has already been fit.");
+ */
+
+ return LinkageTreeUtils.mutualReachability(dist_mat, minPts, alpha);
+ }
+ }
+
+ /**
+ * An implementation of HDBSCAN using the {@link Prim} algorithm
+ * and leveraging {@link KDTree} search spaces
+ *
+ * @author Taylor G Smith
+ */
+ class PrimsKDTree extends KDTreeAlgorithm implements Prim {
+ PrimsKDTree(int leafSize) {
+ super(leafSize);
+ }
+
+ @Override
+ double[][] link() {
+ return primTreeLinkageFunction(dataData);
+ }
+ }
+
+ /**
+ * An implementation of HDBSCAN using the {@link Prim} algorithm
+ * and leveraging {@link BallTree} search spaces
+ *
+ * @author Taylor G Smith
+ */
+ class PrimsBallTree extends BallTreeAlgorithm implements Prim {
+ PrimsBallTree(int leafSize) {
+ super(leafSize);
+ }
+
+ @Override
+ double[][] link() {
+ return primTreeLinkageFunction(dataData);
+ }
+ }
+
+ class BoruvkaKDTree extends KDTreeAlgorithm implements Boruvka {
+ BoruvkaKDTree(int leafSize) {
+ super(leafSize);
+ }
+
+ @Override
+ double[][] link() {
+ return boruvkaTreeLinkageFunction(dataData);
+ }
+ }
+
+ class BoruvkaBallTree extends BallTreeAlgorithm implements Boruvka {
+ BoruvkaBallTree(int leafSize) {
+ super(leafSize);
+ }
+
+ @Override
+ double[][] link() {
+ return boruvkaTreeLinkageFunction(dataData);
+ }
+ }
+
+ /**
+ * A base class for union-find helper classes
+ * to extend. These help join nodes and
+ * branches from trees.
+ *
+ * @author Taylor G Smith
+ */
+ abstract static class UnifiedFinder {
+ final int SIZE;
+
+ UnifiedFinder(int N) {
+ this.SIZE = N;
+ }
+
+ /**
+ * Wraps the index in a python way (-1 = last index).
+ * Easier and more concise than having lots of references to
+ * {@link LinkageTreeUtils#wraparoundIdxGet(int, int)}
+ *
+ * @param i
+ * @param j
+ * @return
+ */
+ static int wrap(int i, int j) {
+ return LinkageTreeUtils.wraparoundIdxGet(i, j);
+ }
+
+ int wrap(int i) {
+ return wrap(SIZE, i);
+ }
+
+ abstract void union(int m, int n);
+
+ abstract int find(int x);
+ }
+
+ // Tested: passing
+ static class TreeUnionFind extends UnifiedFinder {
+ int[][] dataArr;
+ boolean[] is_component;
+
+ public TreeUnionFind(int size) {
+ super(size);
+ dataArr = new int[size][2];
+
+ // First col should be arange to size
+ for (int i = 0; i < size; i++)
+ dataArr[i][0] = i;
+
+ is_component = VecUtils.repBool(true, size);
+ }
+
+ @Override
+ public void union(int x, int y) {
+ int x_root = find(x);
+ int y_root = find(y);
+
+ int x1idx = wrap(x_root);
+ int y1idx = wrap(y_root);
+
+ int dx1 = dataArr[x1idx][1];
+ int dy1 = dataArr[y1idx][1];
+
+ if (dx1 < dy1)
+ dataArr[x1idx][0] = y_root;
+ else if (dx1 > dy1)
+ dataArr[y1idx][0] = x_root;
+ else {
+ dataArr[y1idx][0] = x_root;
+ dataArr[x1idx][1] += 1;
+ }
+ }
+
+ @Override
+ public int find(int x) {
+ final int idx = wrap(x);
+ if (dataArr[idx][0] != x) {
+ dataArr[idx][0] = find(dataArr[idx][0]);
+ is_component[idx] = false;
+ }
+
+ return dataArr[idx][0];
+ }
+
+ /**
+ * Returns all indices still marked true in is_component
+ *
+ * @return
+ */
+ int[] components() {
+ final ArrayList<Integer> h = new ArrayList<>();
+ for (int i = 0; i < is_component.length; i++)
+ if (is_component[i])
+ h.add(i);
+
+ int idx = 0;
+ int[] out = new int[h.size()];
+ for (Integer i : h)
+ out[idx++] = i;
+
+ return out;
+ }
+ }
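+
+ // Illustrative example of the union/find behavior above:
+ //
+ // TreeUnionFind u = new TreeUnionFind(5);
+ // u.union(0, 1); // equal ranks: 1 is attached under root 0
+ // int root = u.find(1); // == 0; also marks index 1 as a non-component
+ // int[] comps = u.components(); // == {0, 2, 3, 4}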
+
+ // Tested: passing
+ static class UnionFind extends UnifiedFinder {
+ int[] parent, size;
+ int nextLabel;
+
+ public UnionFind(int N) {
+ super(N);
+ parent = VecUtils.repInt(-1, 2 * N - 1);
+ nextLabel = N;
+
+ size = new int[2 * N - 1];
+ for (int i = 0; i < size.length; i++)
+ size[i] = i >= N ? 0 : 1; // if N == 5 [1,1,1,1,1,0,0,0,0]
+ }
+
+ int fastFind(int n) {
+ int p = n;
+
+ while (parent[wrap(parent.length, n)] != -1)
+ n = parent[wrap(parent.length, n)];
+
+ // Path compression -- incredibly enraging to debug; skeptics be warned
+ while (parent[wrap(parent.length, p)] != n) {
+ p = parent[wrap(parent.length, p)];
+ parent[wrap(parent.length, p)] = n;
+ }
+
+ return n;
+ }
+
+ @Override
+ public int find(int n) {
+ while (parent[wrap(parent.length, n)] != -1)
+ n = parent[wrap(parent.length, n)];
+ return n;
+ }
+
+ @Override
+ public void union(final int m, final int n) {
+ int mWrap = wrap(size.length, m);
+ int nWrap = wrap(size.length, n);
+
+ size[nextLabel] = size[mWrap] + size[nWrap];
+ parent[mWrap] = nextLabel;
+ parent[nWrap] = nextLabel;
+ nextLabel++;
+ }
+
+ @Override
+ public String toString() {
+ return "Parent arr: " + Arrays.toString(parent) + "; " +
+ "Sizes: " + Arrays.toString(size) + "; " +
+ "Parent: " + Arrays.toString(parent);
+ }
+ }
+
+
+ protected static int[] doLabeling(ArrayList<CompQuadTup<Integer, Integer, Double, Integer>> tree,
+ ArrayList<Integer> clusters, TreeMap<Integer, Integer> clusterMap) {
+
+ CompQuadTup<Integer, Integer, Double, Integer> quad;
+ int rootCluster, parent, child, n = tree.size(), cluster, i;
+ int[] resultArr, parentArr = new int[n], childArr = new int[n];
+ UnifiedFinder unionFind;
+
+ // [parent, child, lambda, size]
+ int maxParent = Integer.MIN_VALUE;
+ int minParent = Integer.MAX_VALUE;
+ for (i = 0; i < n; i++) {
+ quad = tree.get(i);
+ parentArr[i] = quad.getFirst();
+ childArr[i] = quad.getSecond();
+
+ if (quad.getFirst() < minParent)
+ minParent = quad.getFirst();
+ if (quad.getFirst() > maxParent)
+ maxParent = quad.getFirst();
+ }
+
+ rootCluster = minParent;
+ resultArr = new int[rootCluster];
+ unionFind = new TreeUnionFind(maxParent + 1);
+
+ for (i = 0; i < n; i++) {
+ child = childArr[i];
+ parent = parentArr[i];
+ if (!clusters.contains(child))
+ unionFind.union(parent, child);
+ }
+
+ for (i = 0; i < rootCluster; i++) {
+ cluster = unionFind.find(i);
+ if (cluster <= rootCluster)
+ resultArr[i] = NOISE_CLASS;
+ else
+ resultArr[i] = clusterMap.get(cluster);
+ }
+
+ return resultArr;
+ }
+
+ @Override
+ protected HDBSCAN fit() {
+ synchronized (fitLock) {
+ if (null != labels) // Then we've already fit this...
+ return this;
+
+
+ // Meant to prevent multiple .getData() copy calls
+ final LogTimer timer = new LogTimer();
+ dataData = this.data.getData();
+
+ // Build the tree
+ info("constructing HDBSCAN single linkage dendrogram: " + algo);
+ this.tree = algo.initTree(this);
+
+
+ LogTimer treeTimer = new LogTimer();
+ final double[][] lab_tree = tree.link(); // returns the result of the label(..) function
+ info("completed tree building in " + treeTimer.toString());
+
+
+ info("converting tree to labels (" + lab_tree.length + " x " + lab_tree[0].length + ")");
+ LogTimer labTimer = new LogTimer();
+ labels = treeToLabels(dataData, lab_tree, min_cluster_size, this);
+
+
+ // Wrap up...
+ info("completed cluster labeling in " + labTimer.toString());
+
+
+ // Count the noise records
+ numNoisey = 0;
+ for (int lab : labels) if (lab == NOISE_CLASS) numNoisey++;
+
+
+ int nextLabel = LabelHSetFactory.build(labels).size() - (numNoisey > 0 ? 1 : 0);
+ info((numClusters = nextLabel) + " cluster" + (nextLabel != 1 ? "s" : "") +
+ " identified, " + numNoisey + " record" + (numNoisey != 1 ? "s" : "") +
+ " classified noise");
+
+ // Need to encode labels to maintain order
+ final NoiseyLabelEncoder encoder = new NoiseyLabelEncoder(labels).fit();
+ labels = encoder.getEncodedLabels();
+
+
+
+ /*
+ * In this portion, we build the fit summary... HDBSCAN is hard
+ * to iteratively update on status, so we will merely provide summary
+ * statistics on the class labels. Since it's not a centroid-based model
+ * it wouldn't make sense to track any metrics such as WSS, so we'll
+ * leave it at simple counts and pcts.
+ */
+ String label_rep;
+ int[] ordered_label_classes = VecUtils.reorder(encoder.getClasses(), VecUtils.argSort(encoder.getClasses()));
+ for (int label : ordered_label_classes) {
+ label_rep = label + (NOISE_CLASS == label ? " (noise)" : "");
+
+ int count = VecUtils.sum(new VecUtils.IntSeries(labels, Inequality.EQUAL_TO, label).get());
+ double pct = (double) count / (double) labels.length;
+
+ // log the summary
+ fitSummary.add(new Object[]{
+ label_rep,
+ count,
+ pct,
+ timer.wallTime()
+ });
+ }
+
+
+ // Close this model out
+ sayBye(timer);
+
+
+ // Clean anything with big overhead..
+ dataData = null;
+ dist_mat = null;
+ tree = null;
+
+ return this;
+ }
+ }
+
+
+ @Override
+ public int[] getLabels() {
+ return super.handleLabelCopy(labels);
+ }
+
+ @Override
+ public Algo getLoggerTag() {
+ return com.clust4j.log.Log.Tag.Algo.HDBSCAN;
+ }
+
+ @Override
+ public String getName() {
+ return "HDBSCAN";
+ }
+
+ @Override
+ public int getNumberOfIdentifiedClusters() {
+ return numClusters;
+ }
+
+ @Override
+ public int getNumberOfNoisePoints() {
+ return numNoisey;
+ }
+
+ /**
+ * Break up the getLabels method
+ * into numerous smaller ones.
+ *
+ * @author Taylor G Smith
+ */
+ abstract static class GetLabelUtils {
+ /**
+ * Sort the keys of the map in descending order and return
+ * them, eliminating the very smallest key (the root)
+ *
+ * @param stability
+ * @return
+ */
+ protected static <T, P> ArrayList<T> descSortedKeySet(TreeMap<T, P> stability) {
+ int ct = 0;
+ ArrayList<T> nodeList = new ArrayList<>();
+ for (T d : stability.descendingKeySet())
+ if (++ct < stability.size()) // exclude the root...
+ nodeList.add(d);
+
+ return nodeList;
+ }
+
+ /**
+ * Get tuples where child size is over one
+ *
+ * @param tree
+ * @return
+ */
+ protected static EntryPair<ArrayList<double[]>, Integer> childSizeGtOneAndMaxChild(ArrayList<CompQuadTup<Integer, Integer, Double, Integer>> tree) {
+ ArrayList<double[]> out = new ArrayList<>();
+ int max = Integer.MIN_VALUE;
+
+ // [parent, child, lambda, size]
+ for (CompQuadTup<Integer, Integer, Double, Integer> tup : tree) {
+ if (tup.getFourth() > 1)
+ out.add(new double[]{
+ tup.getFirst(),
+ tup.getSecond(),
+ tup.getThird(),
+ tup.getFourth()
+ });
+ else if (tup.getFourth() == 1)
+ max = FastMath.max(max, tup.getSecond());
+ }
+
+ return new EntryPair<>(out, max + 1);
+ }
+
+ protected static TreeMap<Integer, Boolean> initNodeMap(ArrayList<Integer> nodes) {
+ TreeMap<Integer, Boolean> out = new TreeMap<>();
+ for (Integer i : nodes)
+ out.put(i, true);
+ return out;
+ }
+
+ protected static double subTreeStability(ArrayList<double[]> clusterTree,
+ int node, TreeMap<Integer, Double> stability) {
+ double sum = 0;
+
+ // [parent, child, lambda, size]
+ for (double[] d : clusterTree)
+ if ((int) d[0] == node)
+ sum += stability.get((int) d[1]);
+
+ return sum;
+ }
+
+ protected static ArrayList<Integer> breadthFirstSearchFromClusterTree(ArrayList<double[]> tree, Integer bfsRoot) {
+ int child, parent;
+ ArrayList<Integer> result = new ArrayList<>();
+ ArrayList<Integer> toProcess = new ArrayList<>();
+ ArrayList<Integer> tmp;
+
+ toProcess.add(bfsRoot);
+
+ // [parent, child, lambda, size]
+ while (toProcess.size() > 0) {
+ result.addAll(toProcess);
+
+ // python code:
+ // to_process = tree['child'][np.in1d(tree['parent'], to_process)]
+ // For all tuples, if the parent is in toProcess, then
+ // add the child to the new list
+ tmp = new ArrayList<>();
+ for (double[] d : tree) {
+ parent = (int) d[0];
+ child = (int) d[1];
+
+ if (toProcess.contains(parent))
+ tmp.add(child);
+ }
+
+ toProcess = tmp;
+ }
+
+ return result;
+ }
+ }
+
+ protected static int[] getLabels(ArrayList<CompQuadTup<Integer, Integer, Double, Integer>> condensed,
+ TreeMap<Integer, Double> stability) {
+
+ double subTreeStability;
+ ArrayList<Integer> clusters = new ArrayList<>();
+ HSet<Integer> clusterSet;
+ TreeMap<Integer, Integer> clusterMap = new TreeMap<>(),
+ reverseClusterMap = new TreeMap<>();
+
+ // Get descending sorted key set
+ ArrayList<Integer> nodeList = GetLabelUtils.descSortedKeySet(stability);
+
+ // Get tuples where child size > 1
+ EntryPair<ArrayList<double[]>, Integer> entry = GetLabelUtils.childSizeGtOneAndMaxChild(condensed);
+ ArrayList<double[]> clusterTree = entry.getKey();
+
+ // Map of nodes to whether it's a cluster
+ TreeMap<Integer, Boolean> isCluster = GetLabelUtils.initNodeMap(nodeList);
+
+ // Get num points
+ //int numPoints = entry.getValue();
+
+ // Iter over nodes
+ for (Integer node : nodeList) {
+ subTreeStability = GetLabelUtils.subTreeStability(clusterTree, node, stability);
+
+ if (subTreeStability > stability.get(node)) {
+ isCluster.put(node, false);
+ stability.put(node, subTreeStability);
+ } else {
+ for (Integer subNode : GetLabelUtils.breadthFirstSearchFromClusterTree(clusterTree, node))
+ if (subNode.intValue() != node)
+ isCluster.put(subNode, false);
+ }
+
+ }
+
+ // Now add to clusters
+ for (Map.Entry<Integer, Boolean> c : isCluster.entrySet())
+ if (c.getValue())
+ clusters.add(c.getKey());
+ clusterSet = new HSet<>(clusters);
+
+ // Build cluster map
+ int n = 0;
+ for (Integer clust : clusterSet) {
+ clusterMap.put(clust, n);
+ reverseClusterMap.put(n, clust);
+ n++;
+ }
+
+ return doLabeling(condensed, clusters, clusterMap);
+ }
+
+ // Tested: passing
+ static double[][] label(final double[][] tree) {
+ double[][] result;
+ int a, aa, b, bb, index;
+ final int m = tree.length, n = tree[0].length, N = m + 1;
+ double delta;
+
+ result = new double[m][n + 1];
+ UnionFind U = new UnionFind(N);
+
+ for (index = 0; index < m; index++) {
+
+ a = (int) tree[index][0];
+ b = (int) tree[index][1];
+ delta = tree[index][2];
+
+ aa = U.fastFind(a);
+ bb = U.fastFind(b);
+
+ result[index][0] = aa;
+ result[index][1] = bb;
+ result[index][2] = delta;
+ result[index][3] = U.size[aa] + U.size[bb];
+
+ U.union(aa, bb);
+ }
+
+ return result;
+ }
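+
+ // Each row of the matrix returned by label(..) is [left, right, delta, size],
+ // the same layout as a scipy-style single linkage hierarchy; this is the
+ // format condenseTree and treeToLabels below consume.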
+
+ /*
+ protected static double[][] singleLinkage(final double[][] dists) {
+ final double[][] hierarchy = LinkageTreeUtils.minSpanTreeLinkageCore(dists, dists.length);
+ return label(MatUtils.sortAscByCol(hierarchy, 2));
+ }
+ */
+
+ protected static int[] treeToLabels(final double[][] X,
+ final double[][] single_linkage_tree, final int min_size) {
+ return treeToLabels(X, single_linkage_tree, min_size, null);
+ }
+
+ protected static int[] treeToLabels(final double[][] X,
+ final double[][] single_linkage_tree, final int min_size, Loggable logger) {
+
+ final ArrayList<CompQuadTup<Integer, Integer, Double, Integer>> condensed =
+ LinkageTreeUtils.condenseTree(single_linkage_tree, min_size);
+ final TreeMap<Integer, Double> stability = LinkageTreeUtils.computeStability(condensed);
+ return getLabels(condensed, stability);
+ }
+
+ @Override
+ final protected Object[] getModelFitSummaryHeaders() {
+ return new Object[]{
+ "Class Label", "Num. Instances", "Pct. Instances", "Wall"
+ };
+ }
+
+ @Override
+ public int[] predict(RealMatrix newData) {
+ @SuppressWarnings("unused") final int[] fit_labels = getLabels(); // throws the exception if not fit
+ final int n = newData.getColumnDimension();
+
+ if (n != this.data.getColumnDimension())
+ throw new DimensionMismatchException(n, newData.getColumnDimension());
+
+ // TODO: how to predict these???
+ throw new UnsupportedOperationException("HDBSCAN does not currently support predictions");
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/HDBSCANParameters.java b/src/main/java/com/clust4j/algo/HDBSCANParameters.java
new file mode 100644
index 00000000..99aec92e
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/HDBSCANParameters.java
@@ -0,0 +1,144 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.algo.AbstractDBSCAN.AbstractDBSCANParameters;
+import com.clust4j.algo.HDBSCAN.HDBSCAN_Algorithm;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.Random;
+
+/**
+ * A builder class to provide an easier constructing
+ * interface to set custom parameters for HDBSCAN
+ *
+ * @author Taylor G Smith
+ */
+final public class HDBSCANParameters extends AbstractDBSCANParameters<HDBSCAN> {
+ private static final long serialVersionUID = 7197585563308908685L;
+
+ private HDBSCAN_Algorithm algo = HDBSCAN.DEF_ALGO;
+ private double alpha = HDBSCAN.DEF_ALPHA;
+ private boolean approxMinSpanTree = HDBSCAN.DEF_APPROX_MIN_SPAN;
+ private int min_cluster_size = HDBSCAN.DEF_MIN_CLUST_SIZE;
+ private int leafSize = HDBSCAN.DEF_LEAF_SIZE;
+
+
+ public HDBSCANParameters() {
+ this(HDBSCAN.DEF_MIN_PTS);
+ }
+
+ public HDBSCANParameters(final int minPts) {
+ this.minPts = minPts;
+ }
+
+
+ @Override
+ public HDBSCAN fitNewModel(RealMatrix data) {
+ return new HDBSCAN(data, this.copy()).fit();
+ }
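+
+ /*
+ * Typical usage, as a minimal sketch built only from the setters defined
+ * below ("data" here stands in for any RealMatrix of doubles):
+ *
+ * HDBSCAN model = new HDBSCANParameters(5) // minPts = 5
+ * .setMinClustSize(15)
+ * .setVerbose(true)
+ * .fitNewModel(data);
+ * int[] labels = model.getLabels();
+ */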
+
+ @Override
+ public HDBSCANParameters copy() {
+ return new HDBSCANParameters(minPts)
+ .setAlgo(algo)
+ .setAlpha(alpha)
+ .setApprox(approxMinSpanTree)
+ .setLeafSize(leafSize)
+ .setMinClustSize(min_cluster_size)
+ .setMinPts(minPts)
+ .setMetric(metric)
+ .setSeed(seed)
+ .setVerbose(verbose)
+ .setForceParallel(parallel);
+ }
+
+ public HDBSCAN_Algorithm getAlgo() {
+ return this.algo;
+ }
+
+ public HDBSCANParameters setAlgo(final HDBSCAN_Algorithm algo) {
+ this.algo = algo;
+ return this;
+ }
+
+ public double getAlpha() {
+ return alpha;
+ }
+
+ public HDBSCANParameters setAlpha(final double a) {
+ this.alpha = a;
+ return this;
+ }
+
+ public boolean getApprox() {
+ return approxMinSpanTree;
+ }
+
+ public HDBSCANParameters setApprox(final boolean b) {
+ this.approxMinSpanTree = b;
+ return this;
+ }
+
+ public int getLeafSize() {
+ return leafSize;
+ }
+
+ public HDBSCANParameters setLeafSize(final int leafSize) {
+ this.leafSize = leafSize;
+ return this;
+ }
+
+ public int getMinClusterSize() {
+ return min_cluster_size;
+ }
+
+ public HDBSCANParameters setMinClustSize(final int min) {
+ this.min_cluster_size = min;
+ return this;
+ }
+
+ @Override
+ public HDBSCANParameters setMinPts(final int minPts) {
+ this.minPts = minPts;
+ return this;
+ }
+
+ @Override
+ public HDBSCANParameters setForceParallel(boolean b) {
+ this.parallel = b;
+ return this;
+ }
+
+ @Override
+ public HDBSCANParameters setSeed(final Random seed) {
+ this.seed = seed;
+ return this;
+ }
+
+ @Override
+ public HDBSCANParameters setMetric(final GeometricallySeparable dist) {
+ this.metric = dist;
+ return this;
+ }
+
+ public HDBSCANParameters setVerbose(final boolean v) {
+ this.verbose = v;
+ return this;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/HierarchicalAgglomerative.java b/src/main/java/com/clust4j/algo/HierarchicalAgglomerative.java
new file mode 100644
index 00000000..f514fe8d
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/HierarchicalAgglomerative.java
@@ -0,0 +1,659 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.NamedEntity;
+import com.clust4j.kernel.CircularKernel;
+import com.clust4j.kernel.Kernel;
+import com.clust4j.kernel.LogKernel;
+import com.clust4j.log.Log.Tag.Algo;
+import com.clust4j.log.LogTimer;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.scoring.SupervisedMetric;
+import com.clust4j.utils.MatUtils;
+import com.clust4j.utils.SimpleHeap;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+
+import static com.clust4j.metrics.scoring.UnsupervisedMetric.SILHOUETTE;
+
+/**
+ * Agglomerative clustering is a hierarchical clustering process in
+ * which each input record is initially mapped to its own cluster.
+ * Progressively, clusters are merged by locating the least dissimilar
+ * pair of clusters in an M x M distance matrix, merging them, removing
+ * the corresponding rows and columns from the distance matrix and adding
+ * a new row/column vector of distances corresponding to the new cluster,
+ * until only one cluster remains.
+ *
+ * Agglomerative clustering does not scale well to large data, performing
+ * at O(n^2) computationally, yet it outperforms its cousin, Divisive Clustering
+ * (DIANA), which performs at O(2^n).
+ *
+ * @author Taylor G Smith <tgsmith61591@gmail.com>
+ * @see Agglomerative Clustering
+ * @see Divisive Clustering
+ */
+final public class HierarchicalAgglomerative extends AbstractPartitionalClusterer implements UnsupervisedClassifier {
+ /**
+ *
+ */
+ private static final long serialVersionUID = 7563413590708853735L;
+ public static final Linkage DEF_LINKAGE = Linkage.WARD;
+ final static HashSet<Class<? extends Kernel>> comp_avg_unsupported;
+
+ static {
+ comp_avg_unsupported = new HashSet<>();
+ comp_avg_unsupported.add(CircularKernel.class);
+ comp_avg_unsupported.add(LogKernel.class);
+ }
+
+ /**
+ * Which {@link Linkage} to use for the clustering algorithm
+ */
+ final Linkage linkage;
+
+ interface LinkageTreeBuilder extends MetricValidator {
+ public HierarchicalDendrogram buildTree(HierarchicalAgglomerative h);
+ }
+
+ /**
+ * The linkages for agglomerative clustering.
+ *
+ * @author Taylor G Smith
+ */
+ public enum Linkage implements java.io.Serializable, LinkageTreeBuilder {
+ AVERAGE {
+ @Override
+ public AverageLinkageTree buildTree(HierarchicalAgglomerative h) {
+ return h.new AverageLinkageTree();
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable geo) {
+ return !comp_avg_unsupported.contains(geo.getClass());
+ }
+ },
+
+ COMPLETE {
+ @Override
+ public CompleteLinkageTree buildTree(HierarchicalAgglomerative h) {
+ return h.new CompleteLinkageTree();
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable geo) {
+ return !comp_avg_unsupported.contains(geo.getClass());
+ }
+ },
+
+ WARD {
+ @Override
+ public WardTree buildTree(HierarchicalAgglomerative h) {
+ return h.new WardTree();
+ }
+
+ @Override
+ public boolean isValidMetric(GeometricallySeparable geo) {
+ return geo.equals(Distance.EUCLIDEAN);
+ }
+ };
+ }
+
+
+ @Override
+ final public boolean isValidMetric(GeometricallySeparable geo) {
+ return this.linkage.isValidMetric(geo);
+ }
+
+
+ /**
+ * The number of rows in the matrix
+ */
+ final private int m;
+
+
+ /**
+ * The labels for the clusters
+ */
+ volatile private int[] labels = null;
+ /**
+ * The flattened distance vector
+ */
+ volatile private EfficientDistanceMatrix dist_vec = null;
+ volatile HierarchicalDendrogram tree = null;
+ /**
+ * Volatile because it may change later during the build
+ */
+ volatile private int num_clusters;
+
+
+ protected HierarchicalAgglomerative(RealMatrix data) {
+ this(data, new HierarchicalAgglomerativeParameters());
+ }
+
+ protected HierarchicalAgglomerative(RealMatrix data,
+ HierarchicalAgglomerativeParameters planner) {
+ super(data, planner, planner.getNumClusters());
+ this.linkage = planner.getLinkage();
+
+ if (!isValidMetric(this.dist_metric)) {
+ warn(this.dist_metric.getName() + " is invalid for " + this.linkage +
+ ". Falling back to default Euclidean dist");
+ setSeparabilityMetric(DEF_DIST);
+ }
+
+ this.m = data.getRowDimension();
+ this.num_clusters = super.k;
+
+ logModelSummary();
+ }
+
+ @Override
+ final protected ModelSummary modelSummary() {
+ return new ModelSummary(new Object[]{
+ "Num Rows", "Num Cols", "Metric", "Linkage", "Allow Par.", "Num. Clusters"
+ }, new Object[]{
+ data.getRowDimension(), data.getColumnDimension(),
+ getSeparabilityMetric(), linkage,
+ parallel,
+ num_clusters
+ });
+ }
+
+
+ /**
+ * Computes a flattened upper triangular distance matrix in a much more
+ * space efficient manner; traversing it, however, requires intermittent
+ * calculations using {@link #navigate(int, int, int)}
+ *
+ * @author Taylor G Smith
+ */
+ protected static class EfficientDistanceMatrix implements java.io.Serializable {
+ private static final long serialVersionUID = -7329893729526766664L;
+ final protected double[] dists;
+
+ EfficientDistanceMatrix(final RealMatrix data, GeometricallySeparable dist, boolean partial) {
+ this.dists = build(data.getData(), dist, partial);
+ }
+
+ /**
+ * Copy constructor
+ */
+ /*// not needed right now...
+ private EfficientDistanceMatrix(EfficientDistanceMatrix other) {
+ this.dists = VecUtils.copy(other.dists);
+ }
+ */
+
+ /**
+ * Computes a flattened upper triangular distance matrix in a much more
+ * space efficient manner; traversing it, however, requires intermittent
+ * calculations using {@link #getIndexFromFlattenedVec(int, int, int)}
+ *
+ * @param data
+ * @param dist
+ * @param partial -- use the partial distance?
+ * @return a flattened distance vector
+ */
+ static double[] build(final double[][] data, GeometricallySeparable dist, boolean partial) {
+ final int m = data.length;
+ final int s = m * (m - 1) / 2; // The shape of the flattened upper triangular matrix (m choose 2)
+ final double[] vec = new double[s];
+ for (int i = 0, r = 0; i < m - 1; i++)
+ for (int j = i + 1; j < m; j++, r++)
+ vec[r] = partial ? dist.getPartialDistance(data[i], data[j]) :
+ dist.getDistance(data[i], data[j]);
+
+ return vec;
+ }
+
+ /**
+ * For a flattened upper triangular matrix...
+ *
+ * Original:
+ * <pre>
+ * 0 1 2 3
+ * 0 0 1 2
+ * 0 0 0 1
+ * 0 0 0 0
+ * </pre>
+ *
+ * Flattened:
+ *
+ * <1 2 3 1 2 1>
+ *
+ * ...and the parameters m, the original row dimension,
+ * i and j, will identify the corresponding index
+ * in the flattened vector such that mat[0][3] corresponds to vec[2];
+ * this method, then, would return 2 (the index in the vector
+ * corresponding to mat[0][3]) in this case.
+ *
+ * @param m
+ * @param i
+ * @param j
+ * @return the corresponding vector index
+ */
+ static int getIndexFromFlattenedVec(final int m, final int i, final int j) {
+ if (i < j)
+ return m * i - (i * (i + 1) / 2) + (j - i - 1);
+ else if (i > j)
+ return m * j - (j * (j + 1) / 2) + (i - j - 1);
+ throw new IllegalArgumentException(i + ", " + j + "; i should not equal j");
+ }
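+
+ // Worked example of the mapping above with m = 4:
+ // getIndexFromFlattenedVec(4, 0, 3) == 4*0 - 0 + (3 - 0 - 1) == 2,
+ // matching mat[0][3] -> vec[2] in the javadoc; the symmetric query
+ // getIndexFromFlattenedVec(4, 3, 0) resolves to the same index.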
+
+ /**
+ * For a flattened upper triangular matrix...
+ *
+ * Original:
+ * <pre>
+ * 0 1 2 3
+ * 0 0 1 2
+ * 0 0 0 1
+ * 0 0 0 0
+ * </pre>
+ *
+ * Flattened:
+ *
+ * <1 2 3 1 2 1>
+ *
+ * ...and the parameters m, the original row dimension,
+ * i and j, will identify the corresponding value
+ * in the flattened vector such that mat[0][3] corresponds to vec[2];
+ * this method, then, would return 3, the value at index 2, in this case.
+ *
+ * @param m
+ * @param i
+ * @param j
+ * @return the corresponding value in the flattened vector
+ */
+ double navigate(final int m, final int i, final int j) {
+ return dists[getIndexFromFlattenedVec(m, i, j)];
+ }
+ }
+
+ abstract class HierarchicalDendrogram implements java.io.Serializable, NamedEntity {
+ private static final long serialVersionUID = 5295537901834851676L;
+ public final HierarchicalAgglomerative ref;
+ public final GeometricallySeparable dist;
+
+ HierarchicalDendrogram() {
+ ref = HierarchicalAgglomerative.this;
+ dist = ref.getSeparabilityMetric();
+
+ if (null == dist_vec) // why would this happen?
+ dist_vec = new EfficientDistanceMatrix(data, dist, true);
+ }
+
+ double[][] linkage() {
+ // Perform the linkage logic in the tree
+ //EfficientDistanceMatrix y = dist_vec.copy(); // Copy the dist_vec
+
+ double[][] Z = new double[m - 1][4]; // Holding matrix
+ link(dist_vec, Z, m); // Populates Z in place
+
+ // Final linkage tree out...
+ return MatUtils.getColumns(Z, new int[]{0, 1});
+ }
+
+ private void link(final EfficientDistanceMatrix dists, final double[][] Z, final int n) {
+ int i, j, k, x = -1, y = -1, i_start, nx, ny, ni, id_x, id_y, id_i, c_idx;
+ double current_min;
+
+ // Inter cluster dists
+ EfficientDistanceMatrix D = dists; //VecUtils.copy(dists);
+
+ // Map the indices to node ids
+ ref.info("initializing node mappings (" + getClass().getName().split("\\$")[1] + ")");
+ int[] id_map = new int[n];
+ for (i = 0; i < n; i++)
+ id_map[i] = i;
+
+ LogTimer link_timer = new LogTimer(), iterTimer;
+ int incrementor = n / 10, pct = 1;
+ for (k = 0; k < n - 1; k++) {
+ if (incrementor > 0 && k % incrementor == 0)
+ ref.info("node mapping progress - " + 10 * pct++ + "%. Total link time: " +
+ link_timer.toString() + "");
+
+ // get two closest x, y
+ current_min = Double.POSITIVE_INFINITY;
+
+ iterTimer = new LogTimer();
+ for (i = 0; i < n - 1; i++) {
+ if (id_map[i] == -1)
+ continue;
+
+
+ i_start = EfficientDistanceMatrix.getIndexFromFlattenedVec(n, i, i + 1);
+ for (j = 0; j < n - i - 1; j++) {
+ if (D.dists[i_start + j] < current_min) {
+ current_min = D.dists[i_start + j];
+ x = i;
+ y = i + j + 1;
+ }
+ }
+ }
+
+ id_x = id_map[x];
+ id_y = id_map[y];
+
+ // Get original num points in clusters x,y
+ nx = id_x < n ? 1 : (int) Z[id_x - n][3];
+ ny = id_y < n ? 1 : (int) Z[id_y - n][3];
+
+ // Record new node
+ Z[k][0] = FastMath.min(id_x, id_y);
+ Z[k][1] = FastMath.max(id_y, id_x);
+ Z[k][2] = current_min;
+ Z[k][3] = nx + ny;
+ id_map[x] = -1; // cluster x to be dropped
+ id_map[y] = n + k; // cluster y replaced
+
+ // update dist mat
+ int cont = 0;
+ for (i = 0; i < n; i++) {
+ id_i = id_map[i];
+ if (id_i == -1 || id_i == n + k) {
+ cont++;
+ continue;
+ }
+
+ ni = id_i < n ? 1 : (int) Z[id_i - n][3];
+ c_idx = EfficientDistanceMatrix.getIndexFromFlattenedVec(n, i, y);
+ D.dists[c_idx] = getDist(D.navigate(n, i, x), D.dists[c_idx], current_min, nx, ny, ni);
+
+ if (i < x)
+ D.dists[EfficientDistanceMatrix.getIndexFromFlattenedVec(n, i, x)] = Double.POSITIVE_INFINITY;
+ }
+
+ fitSummary.add(new Object[]{
+ k, current_min, cont, iterTimer.formatTime(),
+ link_timer.formatTime(), link_timer.wallMsg()
+ });
+ }
+ }
+
+ abstract protected double getDist(final double dx, final double dy,
+ final double current_min, final int nx, final int ny, final int ni);
+ }
+
+ class WardTree extends HierarchicalDendrogram {
+ private static final long serialVersionUID = -2336170779406847047L;
+
+ public WardTree() {
+ super();
+ }
+
+ @Override
+ protected double getDist(double dx, double dy,
+ double current_min, int nx, int ny, int ni) {
+
+ final double t = 1.0 / (nx + ny + ni);
+ return FastMath.sqrt((ni + nx) * t * dx * dx +
+ (ni + ny) * t * dy * dy -
+ ni * t * current_min * current_min);
+ }
+
+ @Override
+ public String getName() {
+ return "Ward Tree";
+ }
+ }
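+
+ // Note: WardTree.getDist is the Lance-Williams update for Ward's method,
+ // where dx = d(i,x), dy = d(i,y) and current_min = d(x,y):
+ // d(i, x U y) = sqrt(((ni+nx)*dx^2 + (ni+ny)*dy^2 - ni*d(x,y)^2) / (nx+ny+ni))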
+
+ abstract class LinkageTree extends HierarchicalDendrogram {
+ private static final long serialVersionUID = -252115690411913842L;
+
+ public LinkageTree() {
+ super();
+ }
+ }
+
+ class AverageLinkageTree extends LinkageTree {
+ private static final long serialVersionUID = 5891407873391751152L;
+
+ public AverageLinkageTree() {
+ super();
+ }
+
+ @Override
+ protected double getDist(double dx, double dy,
+ double current_min, int nx, int ny, int ni) {
+ return (nx * dx + ny * dy) / (double) (nx + ny);
+ }
+
+ @Override
+ public String getName() {
+ return "Avg Linkage Tree";
+ }
+ }
+
+ class CompleteLinkageTree extends LinkageTree {
+ private static final long serialVersionUID = 7407993870975009576L;
+
+ public CompleteLinkageTree() {
+ super();
+ }
+
+ @Override
+ protected double getDist(double dx, double dy,
+ double current_min, int nx, int ny, int ni) {
+ return FastMath.max(dx, dy);
+ }
+
+ @Override
+ public String getName() {
+ return "Complete Linkage Tree";
+ }
+ }
+
+
+ @Override
+ public String getName() {
+ return "Agglomerative";
+ }
+
+ public Linkage getLinkage() {
+ return linkage;
+ }
+
+ @Override
+ protected HierarchicalAgglomerative fit() {
+ synchronized (fitLock) {
+ if (null != labels) // already fit
+ return this;
+
+ final LogTimer timer = new LogTimer();
+ labels = new int[m];
+
+ /*
+ * Corner case: k = 1 (due to singularity?)
+ */
+ if (1 == k) {
+ this.fitSummary.add(new Object[]{
+ 0, 0, Double.NaN, timer.formatTime(), timer.formatTime(), timer.wallMsg()
+ });
+
+ warn("converged immediately due to " + (this.singular_value ?
+ "singular nature of input matrix" : "k = 1"));
+ sayBye(timer);
+ return this;
+ }
+
+ dist_vec = new EfficientDistanceMatrix(data, getSeparabilityMetric(), true);
+
+ // Log info...
+ info("computed distance matrix in " + timer.toString());
+
+
+ // Get the tree class for logging...
+ LogTimer treeTimer = new LogTimer();
+ this.tree = this.linkage.buildTree(this);
+
+ // Tree build
+ info("constructed " + tree.getName() + " HierarchicalDendrogram in " + treeTimer.toString());
+ double[][] children = tree.linkage();
+
+
+ // Cut the tree
+ labels = hcCut(num_clusters, children, m);
+ labels = new SafeLabelEncoder(labels).fit().getEncodedLabels();
+
+
+ sayBye(timer);
+ dist_vec = null;
+ return this;
+ }
+
+ } // End train
+
+ static int[] hcCut(final int n_clusters, final double[][] children, final int n_leaves) {
+ /*
+ * Leave children as a double[][] despite it
+ * holding ints; this allows VecUtils to operate on it
+ */
+
+ if (n_clusters > n_leaves)
+ throw new InternalError(n_clusters + " > " + n_leaves);
+
+ // Init nodes
+ SimpleHeap<Integer> nodes = new SimpleHeap<>(-((int) VecUtils.max(children[children.length - 1]) + 1));
+
+
+ for (int i = 0; i < n_clusters - 1; i++) {
+ int inner_idx = -nodes.get(0) - n_leaves;
+ if (inner_idx < 0)
+ inner_idx = children.length + inner_idx;
+
+ double[] these_children = children[inner_idx];
+ nodes.push(-((int) these_children[0]));
+ nodes.pushPop(-((int) these_children[1]));
+ }
+
+ int i = 0;
+ final int[] labels = new int[n_leaves];
+ for (Integer node : nodes) {
+ Integer[] descendants = hcGetDescendents(-node, children, n_leaves);
+ for (Integer desc : descendants)
+ labels[desc] = i;
+
+ i++;
+ }
+
+ return labels;
+ }
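+
+ // hcCut pops the n_clusters - 1 most recent merges (node ids are negated,
+ // assuming SimpleHeap surfaces its smallest element first, so the largest
+ // id pops first); the descendants of the remaining forest roots then
+ // define the flat cluster labels.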
+
+ static Integer[] hcGetDescendents(int node, double[][] children, int leaves) {
+ if (node < leaves)
+ return new Integer[]{node};
+
+ final SimpleHeap<Integer> ind = new SimpleHeap<>(node);
+ final ArrayList<Integer> descendent = new ArrayList<>();
+ int i, n_indices = 1;
+
+ while (n_indices > 0) {
+ i = ind.popInPlace();
+ if (i < leaves) {
+ descendent.add(i);
+ n_indices--;
+ } else {
+ final double[] chils = children[i - leaves];
+ for (double d : chils)
+ ind.add((int) d);
+ n_indices++;
+ }
+ }
+
+ return descendent.toArray(new Integer[descendent.size()]);
+ }
+
+ @Override
+ public int[] getLabels() {
+ return super.handleLabelCopy(labels);
+ }
+
+
+ @Override
+ public Algo getLoggerTag() {
+ return com.clust4j.log.Log.Tag.Algo.AGGLOMERATIVE;
+ }
+
+ @Override
+ final protected Object[] getModelFitSummaryHeaders() {
+ return new Object[]{
+ "Link Iter. #", "Iter. Min", "Continues", "Iter. Time", "Total Time", "Wall"
+ };
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double indexAffinityScore(int[] labels) {
+ // Propagates ModelNotFitException
+ return SupervisedMetric.INDEX_AFFINITY.evaluate(labels, getLabels());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double silhouetteScore() {
+ // Propagates ModelNotFitException
+ return SILHOUETTE.evaluate(this, getLabels());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int[] predict(RealMatrix newData) {
+ final int[] fit_labels = getLabels(); // throws the MNF exception if not fit
+ final int numSamples = newData.getRowDimension(), n = newData.getColumnDimension();
+
+ // Make sure matches dimensionally
+ if (n != this.data.getColumnDimension())
+ throw new DimensionMismatchException(n, data.getColumnDimension());
+
+ /*
+ * There's no great way to predict on a hierarchical
+ * algorithm, so we'll treat this like a CentroidLearner,
+ * create centroids from the k clusters formed, then
+ * predict via the CentroidUtils. This works because
+ * Hierarchical is not a NoiseyClusterer
+ */
+
+ // CORNER CASE: num_clusters == 1, return only label (0)
+ if (1 == num_clusters)
+ return VecUtils.repInt(fit_labels[0], numSamples);
+
+ return new NearestCentroidParameters()
+ .setMetric(this.dist_metric) // if it fails, falls back to default Euclidean...
+ .setVerbose(false) // just to be sure in case default ever changes...
+ .fitNewModel(this.getData(), fit_labels)
+ .predict(newData);
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/HierarchicalAgglomerativeParameters.java b/src/main/java/com/clust4j/algo/HierarchicalAgglomerativeParameters.java
new file mode 100644
index 00000000..9dd32697
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/HierarchicalAgglomerativeParameters.java
@@ -0,0 +1,103 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.algo.HierarchicalAgglomerative.Linkage;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.Random;
+
+final public class HierarchicalAgglomerativeParameters
+ extends BaseClustererParameters
+ implements UnsupervisedClassifierParameters<HierarchicalAgglomerative> {
+
+ private static final long serialVersionUID = -1333222392991867085L;
+ private static int DEF_K = 2;
+ private Linkage linkage = HierarchicalAgglomerative.DEF_LINKAGE;
+ private int num_clusters = DEF_K;
+
+ public HierarchicalAgglomerativeParameters() {
+ this(DEF_K);
+ }
+
+ public HierarchicalAgglomerativeParameters(int k) {
+ this.num_clusters = k;
+ }
+
+ public HierarchicalAgglomerativeParameters(Linkage linkage) {
+ this();
+ this.linkage = linkage;
+ }
+
+ @Override
+ public HierarchicalAgglomerative fitNewModel(RealMatrix data) {
+ return new HierarchicalAgglomerative(data, this.copy()).fit();
+ }
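+
+ /*
+ * Typical usage, as a minimal sketch built only from the setters defined
+ * below ("data" here stands in for any RealMatrix of doubles):
+ *
+ * int[] labels = new HierarchicalAgglomerativeParameters(3)
+ * .setLinkage(Linkage.AVERAGE)
+ * .fitNewModel(data)
+ * .getLabels();
+ */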
+
+ @Override
+ public HierarchicalAgglomerativeParameters copy() {
+ return new HierarchicalAgglomerativeParameters(linkage)
+ .setMetric(metric)
+ .setSeed(seed)
+ .setVerbose(verbose)
+ .setNumClusters(num_clusters)
+ .setForceParallel(parallel);
+ }
+
+ public Linkage getLinkage() {
+ return linkage;
+ }
+
+ public HierarchicalAgglomerativeParameters setLinkage(Linkage l) {
+ this.linkage = l;
+ return this;
+ }
+
+ public int getNumClusters() {
+ return num_clusters;
+ }
+
+ public HierarchicalAgglomerativeParameters setNumClusters(final int d) {
+ this.num_clusters = d;
+ return this;
+ }
+
+ @Override
+ public HierarchicalAgglomerativeParameters setForceParallel(boolean b) {
+ this.parallel = b;
+ return this;
+ }
+
+ @Override
+ public HierarchicalAgglomerativeParameters setSeed(final Random seed) {
+ this.seed = seed;
+ return this;
+ }
+
+ @Override
+ public HierarchicalAgglomerativeParameters setVerbose(boolean b) {
+ this.verbose = b;
+ return this;
+ }
+
+ @Override
+ public HierarchicalAgglomerativeParameters setMetric(GeometricallySeparable dist) {
+ this.metric = dist;
+ return this;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/KDTree.java b/src/main/java/com/clust4j/algo/KDTree.java
new file mode 100644
index 00000000..c34bb9e3
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/KDTree.java
@@ -0,0 +1,297 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.log.Loggable;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.DistanceMetric;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.pairwise.MinkowskiDistance;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.HashSet;
+
+/**
+ * A k-d tree (short for k-dimensional tree) is a space-partitioning
+ * data structure for organizing points in a k-dimensional space. k-d
+ * trees are a useful data structure for several applications, such as searches
+ * involving a multidimensional search key (e.g. range searches and nearest
+ * neighbor searches). k-d trees are a special case of binary space partitioning trees.
+ *
+ * @author Taylor G Smith
+ * @see NearestNeighborHeapSearch
+ * @see k-d trees
+ */
+public class KDTree extends NearestNeighborHeapSearch {
+ private static final long serialVersionUID = -3744545394278454548L;
+ public final static HashSet<Class<? extends GeometricallySeparable>> VALID_METRICS;
+
+ static {
+ VALID_METRICS = new HashSet<>();
+ VALID_METRICS.add(Distance.EUCLIDEAN.getClass());
+ VALID_METRICS.add(Distance.MANHATTAN.getClass());
+ VALID_METRICS.add(MinkowskiDistance.class);
+ VALID_METRICS.add(Distance.CHEBYSHEV.getClass());
+ }
+
+
+ @Override
+ boolean checkValidDistMet(GeometricallySeparable dist) {
+ return VALID_METRICS.contains(dist.getClass());
+ }
+
+
+ public KDTree(final RealMatrix X) {
+ super(X);
+ }
+
+ public KDTree(final RealMatrix X, int leaf_size) {
+ super(X, leaf_size);
+ }
+
+ public KDTree(final RealMatrix X, DistanceMetric dist) {
+ super(X, dist);
+ }
+
+ public KDTree(final RealMatrix X, Loggable logger) {
+ super(X, logger);
+ }
+
+ public KDTree(final RealMatrix X, int leaf_size, DistanceMetric dist) {
+ super(X, leaf_size, dist);
+ }
+
+ public KDTree(final RealMatrix X, int leaf_size, DistanceMetric dist, Loggable logger) {
+ super(X, leaf_size, dist, logger);
+ }
+
+ protected KDTree(final double[][] X, int leaf_size, DistanceMetric dist, Loggable logger) {
+ super(X, leaf_size, dist, logger);
+ }
+
+ /**
+ * Constructor with logger and distance metric
+ *
+ * @param X
+ * @param dist
+ * @param logger
+ */
+ public KDTree(final RealMatrix X, DistanceMetric dist, Loggable logger) {
+ super(X, dist, logger);
+ }
+
+
+ @Override
+ void allocateData(NearestNeighborHeapSearch tree, int n_nodes, int n_features) {
+ tree.node_bounds = new double[2][n_nodes][n_features];
+ }
+
+ @Override
+ void initNode(NearestNeighborHeapSearch tree, int i_node, int idx_start, int idx_end) {
+ int n_features = tree.N_FEATURES, i, j;
+ double rad = 0;
+
+ double[] lowerBounds = tree.node_bounds[0][i_node];
+ double[] upperBounds = tree.node_bounds[1][i_node];
+ double[][] data = tree.data_arr;
+ int[] idx_array = tree.idx_array;
+ double[] data_row;
+
+ // Get node bounds
+ for (j = 0; j < n_features; j++) {
+ lowerBounds[j] = Double.POSITIVE_INFINITY;
+ upperBounds[j] = Double.NEGATIVE_INFINITY;
+ }
+
+ // Compute data range
+ for (i = idx_start; i < idx_end; i++) {
+ data_row = data[idx_array[i]];
+
+ for (j = 0; j < n_features; j++) {
+ lowerBounds[j] = FastMath.min(lowerBounds[j], data_row[j]);
+ upperBounds[j] = FastMath.max(upperBounds[j], data_row[j]);
+ }
+
+ // The python code does not increment up to the range boundary,
+ // the java for loop does. So we must decrement j by one.
+ j--;
+
+ if (tree.infinity_dist)
+ rad = FastMath.max(rad, 0.5 * (upperBounds[j] - lowerBounds[j]));
+ else
+ rad += FastMath.pow(
+ 0.5 * FastMath.abs(upperBounds[j] - lowerBounds[j]),
+ tree.dist_metric.getP());
+ }
+
+ tree.node_data[i_node].idx_start = idx_start;
+ tree.node_data[i_node].idx_end = idx_end;
+
+ // radius assignment
+ tree.node_data[i_node].radius = Math.pow(rad, 1.0 / tree.dist_metric.getP());
+ }
+
+ @Override
+ final KDTree newInstance(double[][] arr, int leaf, DistanceMetric dist, Loggable logger) {
+ return new KDTree(new Array2DRowRealMatrix(arr, false), leaf, dist, logger);
+ }
+
+ @Override
+ double minDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ double d = minRDist(tree, i_node, pt);
+ return tree.dist_metric.partialDistanceToDistance(d);
+ }
+
+ @Override
+ double minDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2, int iNode2) {
+ return tree1.dist_metric.partialDistanceToDistance(minRDistDual(tree1, iNode1, tree2, iNode2));
+ }
+
+ @Override
+ double minRDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ double d_lo, d_hi, d, rdist = 0.0, p = tree.dist_metric.getP();
+ final boolean inf = tree.infinity_dist;
+
+ for (int j = 0; j < N_FEATURES; j++) {
+ d_lo = tree.node_bounds[0][i_node][j] - pt[j];
+ d_hi = pt[j] - tree.node_bounds[1][i_node][j];
+ d = (d_lo + FastMath.abs(d_lo)) + (d_hi + FastMath.abs(d_hi));
+
+ rdist = inf ? FastMath.max(rdist, 0.5 * d) :
+ rdist + FastMath.pow(0.5 * d, p);
+ }
+
+ return rdist;
+ }
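+
+ // In the loop above at most one of d_lo and d_hi can be positive, so 0.5 * d
+ // equals max(d_lo, 0) + max(d_hi, 0): the distance from pt[j] to the node's
+ // bounding interval along feature j (zero when the point lies inside it).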
+
+ @Override
+ double minRDistDual(NearestNeighborHeapSearch tree1, int i_node1, NearestNeighborHeapSearch tree2, int i_node2) {
+ double d, d1, d2, rdist = 0.0, p = tree1.dist_metric.getP();
+ int j, n_features = tree1.N_FEATURES;
+ boolean inf = tree1.infinity_dist;
+
+ for (j = 0; j < n_features; j++) {
+ d1 = (tree1.node_bounds[0][i_node1][j] - tree2.node_bounds[1][i_node2][j]);
+ d2 = (tree2.node_bounds[0][i_node2][j] - tree1.node_bounds[1][i_node1][j]);
+ d = (d1 + FastMath.abs(d1)) + (d2 + FastMath.abs(d2));
+ rdist = inf ? FastMath.max(rdist, 0.5 * d) :
+ rdist + FastMath.pow(0.5 * d, p);
+ }
+
+ return rdist;
+ }
+
+ @Override
+ double maxDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2, int iNode2) {
+ return tree1.dist_metric.partialDistanceToDistance(maxRDistDual(tree1, iNode1, tree2, iNode2));
+ }
+
+ /*
+ @Override
+ double maxDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ double d = maxRDist(tree, i_node, pt);
+ return tree.dist_metric.partialDistanceToDistance(d);
+ }
+
+ @Override
+ double maxRDist(NearestNeighborHeapSearch tree, int i_node, double[] pt) {
+ double d_lo, d_hi, rdist = 0.0, p = tree.dist_metric.getP();
+ boolean inf = tree.infinity_dist;
+ int n_features = tree.N_FEATURES;
+
+ if(inf) {
+ for(int j = 0; j < n_features; j++) {
+ rdist = FastMath.max(rdist, FastMath.abs(pt[j] - tree.node_bounds[0][i_node][j]));
+ rdist = FastMath.max(rdist, FastMath.abs(pt[j] - tree.node_bounds[1][i_node][j]));
+ }
+ } else {
+ for(int j = 0; j < n_features; j++) {
+ d_lo = FastMath.abs(pt[j] - tree.node_bounds[0][i_node][j]);
+ d_hi = FastMath.abs(pt[j] - tree.node_bounds[1][i_node][j]);
+ rdist += FastMath.pow(FastMath.max(d_lo, d_hi), p);
+ }
+ }
+
+ return rdist;
+ }
+ */
+
+ @Override
+ double maxRDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2, int iNode2) {
+ double d1, d2, rdist = 0.0, p = tree1.dist_metric.getP();
+ int j, n_features = tree1.N_FEATURES;
+ final boolean inf = tree1.infinity_dist;
+
+ if (inf) {
+ for (j = 0; j < n_features; j++) {
+ rdist = FastMath.max(rdist,
+ FastMath.abs(tree1.node_bounds[0][iNode1][j]
+ - tree2.node_bounds[1][iNode2][j]));
+ rdist = FastMath.max(rdist,
+ FastMath.abs(tree1.node_bounds[1][iNode1][j]
+ - tree2.node_bounds[0][iNode2][j]));
+ }
+ } else {
+ for (j = 0; j < n_features; j++) {
+ d1 = FastMath.abs(tree1.node_bounds[0][iNode1][j]
+ - tree2.node_bounds[1][iNode2][j]);
+ d2 = FastMath.abs(tree1.node_bounds[1][iNode1][j]
+ - tree2.node_bounds[0][iNode2][j]);
+ rdist += FastMath.pow(FastMath.max(d1, d2), p);
+ }
+ }
+
+ return rdist;
+ }
+
+
+ @Override
+ void minMaxDist(NearestNeighborHeapSearch tree, int i_node, double[] pt, MutableDouble minDist, MutableDouble maxDist) {
+ double d, d_lo, d_hi, p = tree.dist_metric.getP();
+ int j, n_features = tree.N_FEATURES;
+ boolean inf = tree.infinity_dist;
+
+ minDist.value = 0.0;
+ maxDist.value = 0.0;
+
+ for (j = 0; j < n_features; j++) {
+ d_lo = tree.node_bounds[0][i_node][j] - pt[j];
+ d_hi = pt[j] - tree.node_bounds[1][i_node][j];
+ d = (d_lo + FastMath.abs(d_lo)) + (d_hi + FastMath.abs(d_hi));
+
+ if (inf) {
+ minDist.value = FastMath.max(minDist.value, 0.5 * d);
+ maxDist.value = FastMath.max(maxDist.value,
+ FastMath.abs(pt[j] - tree.node_bounds[0][i_node][j]));
+ maxDist.value = FastMath.max(maxDist.value,
+ FastMath.abs(pt[j] - tree.node_bounds[1][i_node][j]));
+ } else {
+ minDist.value += FastMath.pow(0.5 * d, p);
+ maxDist.value += FastMath.pow(
+ FastMath.max(FastMath.abs(d_lo), FastMath.abs(d_hi)), p);
+ }
+ }
+
+
+ if (!inf) {
+ double pow = 1.0 / p;
+ minDist.value = FastMath.pow(minDist.value, pow);
+ maxDist.value = FastMath.pow(maxDist.value, pow);
+ }
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/KMeans.java b/src/main/java/com/clust4j/algo/KMeans.java
new file mode 100644
index 00000000..9e3e9737
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/KMeans.java
@@ -0,0 +1,267 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.except.NaNException;
+import com.clust4j.log.Log.Tag.Algo;
+import com.clust4j.log.LogTimer;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.utils.EntryPair;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.ArrayList;
+import java.util.TreeMap;
+
+/**
+ * KMeans clustering is
+ * a method of vector quantization, originally from signal processing, that is popular
+ * for cluster analysis in data mining. KMeans clustering aims to partition m
+ * observations into k clusters in which each observation belongs to the cluster
+ * with the nearest mean, serving as a prototype of the cluster. This results in
+ * a partitioning of the data space into Voronoi cells.
+ *
+ * @author Taylor G Smith <tgsmith61591@gmail.com>
+ */
+final public class KMeans extends AbstractCentroidClusterer {
+ private static final long serialVersionUID = 1102324012006818767L;
+ final public static GeometricallySeparable DEF_DIST = Distance.EUCLIDEAN;
+ final public static int DEF_MAX_ITER = 100;
+
+
+ protected KMeans(final RealMatrix data) {
+ this(data, DEF_K);
+ }
+
+ protected KMeans(final RealMatrix data, final int k) {
+ this(data, new KMeansParameters(k));
+ }
+
+ protected KMeans(final RealMatrix data, final KMeansParameters planner) {
+ super(data, planner);
+ }
+
+
+ @Override
+ public String getName() {
+ return "KMeans";
+ }
+
+ @Override
+ protected KMeans fit() {
+ synchronized (fitLock) {
+
+ if (null != labels) // already fit
+ return this;
+
+
+ final LogTimer timer = new LogTimer();
+ final double[][] X = data.getData();
+ final int n = data.getColumnDimension();
+ final double nan = Double.NaN;
+
+
+ // Corner case: K = 1 or all singular values
+ if (1 == k) {
+ labelFromSingularK(X);
+ fitSummary.add(new Object[]{iter, converged, tss, tss, nan, timer.wallTime()});
+ sayBye(timer);
+ return this;
+ }
+
+
+ // Nearest centroid model to predict labels
+ NearestCentroid model = null;
+ EntryPair<int[], double[]> label_dist;
+
+
+ // Keep track of WSS (within-cluster sums of squared distances) across iterations
+ double last_wss_sum = Double.POSITIVE_INFINITY, wss_sum = 0;
+ ArrayList<double[]> new_centroids;
+
+ for (iter = 0; iter < maxIter; iter++) {
+
+ // Get labels for nearest centroids
+ try {
+ model = new NearestCentroid(CentroidUtils.centroidsToMatrix(centroids, false),
+ VecUtils.arange(k), new NearestCentroidParameters()
+ .setSeed(getSeed())
+ .setMetric(getSeparabilityMetric())
+ .setVerbose(false)).fit();
+ } catch (NaNException NaN) {
+ /*
+                     * If the metric used produces lots of infs or -infs, it
+ * makes it hard if not impossible to effectively segment the
+ * input space. Thus, the centroid assignment portion below can
+ * yield a zero count (denominator) for one or more of the centroids
+ * which makes the entire row NaN. We should tell the user to
+ * try a different metric, if that's the case.
+ *
+ error(new IllegalClusterStateException(dist_metric.getName()+" produced an entirely " +
+ "infinite distance matrix, making it difficult to segment the input space. Try a different " +
+ "metric."));
+ */
+ this.k = 1;
+ warn("(dis)similarity metric (" + dist_metric + ") cannot partition space without propagating Infs. Returning one cluster");
+
+ labelFromSingularK(X);
+ fitSummary.add(new Object[]{iter, converged, tss, tss, nan, timer.wallTime()});
+ sayBye(timer);
+ return this;
+ }
+
+ label_dist = model.predict(X);
+
+ // unpack the EntryPair
+ labels = label_dist.getKey();
+ new_centroids = new ArrayList<>(k);
+
+
+ int label;
+ wss = new double[k];
+ int[] centroid_counts = new int[k];
+ double[] centroid;
+ double[][] new_centroid_arrays = new double[k][n];
+ for (int i = 0; i < m; i++) {
+ label = labels[i];
+ centroid = centroids.get(label);
+
+ // increment count for this centroid
+ double this_cost = 0;
+ centroid_counts[label]++;
+ for (int j = 0; j < centroid.length; j++) {
+ double diff = X[i][j] - centroid[j];
+ this_cost += (diff * diff);
+
+                        // Add to the centroid sums
+ new_centroid_arrays[label][j] += X[i][j];
+ }
+
+ // add this cost to the WSS
+ wss[label] += this_cost;
+ }
+
+ // one pass of K for some consolidation
+ wss_sum = 0;
+ for (int i = 0; i < k; i++) {
+ wss_sum += wss[i];
+
+ for (int j = 0; j < n; j++) // meanify
+ new_centroid_arrays[i][j] /= (double) centroid_counts[i];
+
+ new_centroids.add(new_centroid_arrays[i]);
+ }
+
+ // update the BSS
+ bss = tss - wss_sum;
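+                // (the total sum of squares decomposes as TSS = WSS + BSS,
+                // so BSS falls out once the within-cluster sum is known)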
+
+
+ // Assign new centroids
+ double diff = last_wss_sum - wss_sum;
+ last_wss_sum = wss_sum;
+
+
+ // Check for convergence and add summary:
+ converged = FastMath.abs(diff) < tolerance; // first iter will be inf
+ fitSummary.add(new Object[]{
+ converged ? iter++ : iter,
+ converged,
+ tss, wss_sum, bss,
+ timer.wallTime()});
+
+ if (converged) {
+ break;
+ } else {
+ // otherwise, reassign centroids
+ centroids = new_centroids;
+ }
+
+ } // end iterations
+
+
+ // Reorder the labels, centroids and wss indices
+ reorderLabelsAndCentroids();
+
+ if (!converged)
+ warn("algorithm did not converge");
+
+
+ // wrap things up, create summary..
+ sayBye(timer);
+
+
+ return this;
+ }
+
+ }
+
+
+ @Override
+ public Algo getLoggerTag() {
+ return com.clust4j.log.Log.Tag.Algo.KMEANS;
+ }
+
+ @Override
+ protected Object[] getModelFitSummaryHeaders() {
+ return new Object[]{
+ "Iter. #", "Converged", "TSS", "WSS", "BSS", "Wall"
+ };
+ }
+
+ /**
+ * Reorder the labels in order of appearance using the
+ * {@link LabelEncoder}. Also reorder the centroids to correspond
+ * with new label order
+ */
+ @Override
+ protected void reorderLabelsAndCentroids() {
+ boolean wss_null = null == wss;
+
+ /*
+ * reorder labels...
+ */
+ final LabelEncoder encoder = new LabelEncoder(labels).fit();
+ labels = encoder.getEncodedLabels();
+
+ // also reorder centroids... takes O(2K) passes
+        TreeMap<Integer, double[]> tmpCentroids = new TreeMap<>();
+ double[] new_wss = new double[k];
+
+ /*
+ * We have to be delicate about this--KMedoids stores
+ * labels as indices pointing to which record is the medoid,
+ * whereas KMeans uses 0 thru K. Thus we can simply index in
+         * KMeans, but will get an IndexOOB exception in KMedoids, so
+ * we need to come up with a universal solution which might
+ * look ugly at a glance, but is robust to both.
+ */
+ int encoded;
+ for (int i = 0; i < k; i++) {
+ encoded = encoder.reverseEncodeOrNull(i);
+ tmpCentroids.put(i, centroids.get(encoded));
+
+ new_wss[i] = wss_null ? Double.NaN : wss[encoded];
+ }
+
+ for (int i = 0; i < k; i++)
+ centroids.set(i, tmpCentroids.get(i));
+
+ // reset wss
+ this.wss = new_wss;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/KMeansParameters.java b/src/main/java/com/clust4j/algo/KMeansParameters.java
new file mode 100644
index 00000000..27fef154
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/KMeansParameters.java
@@ -0,0 +1,105 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.algo.AbstractCentroidClusterer.InitializationStrategy;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.Random;
+
+final public class KMeansParameters extends CentroidClustererParameters<KMeans> {
+ private static final long serialVersionUID = -813106538623499760L;
+
+ private InitializationStrategy strat = KMeans.DEF_INIT;
+ private int maxIter = KMeans.DEF_MAX_ITER;
+
+ public KMeansParameters() {
+ }
+
+ public KMeansParameters(int k) {
+ this.k = k;
+ }
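+
+    /*
+     * Fluent usage sketch (values are arbitrary, for illustration only):
+     *
+     *   KMeans km = new KMeansParameters(3)
+     *       .setMaxIter(50)
+     *       .setVerbose(true)
+     *       .fitNewModel(data);
+     */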
+
+ @Override
+ public KMeans fitNewModel(final RealMatrix data) {
+ return new KMeans(data, this.copy()).fit();
+ }
+
+ @Override
+ public KMeansParameters copy() {
+ return new KMeansParameters(k)
+ .setMaxIter(maxIter)
+ .setConvergenceCriteria(minChange)
+ .setMetric(metric)
+ .setVerbose(verbose)
+ .setSeed(seed)
+ .setInitializationStrategy(strat)
+ .setForceParallel(parallel);
+ }
+
+ @Override
+ public InitializationStrategy getInitializationStrategy() {
+ return strat;
+ }
+
+ @Override
+ public int getMaxIter() {
+ return maxIter;
+ }
+
+ @Override
+ public KMeansParameters setForceParallel(boolean b) {
+ this.parallel = b;
+ return this;
+ }
+
+ @Override
+ public KMeansParameters setMetric(final GeometricallySeparable dist) {
+ this.metric = dist;
+ return this;
+ }
+
+ public KMeansParameters setMaxIter(final int max) {
+ this.maxIter = max;
+ return this;
+ }
+
+ @Override
+ public KMeansParameters setConvergenceCriteria(final double min) {
+ this.minChange = min;
+ return this;
+ }
+
+ @Override
+ public KMeansParameters setInitializationStrategy(InitializationStrategy init) {
+ this.strat = init;
+ return this;
+ }
+
+ @Override
+ public KMeansParameters setSeed(final Random seed) {
+ this.seed = seed;
+ return this;
+ }
+
+ @Override
+ public KMeansParameters setVerbose(final boolean v) {
+ this.verbose = v;
+ return this;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/KMedoids.java b/src/main/java/com/clust4j/algo/KMedoids.java
new file mode 100644
index 00000000..f2b96a7a
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/KMedoids.java
@@ -0,0 +1,499 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.except.IllegalClusterStateException;
+import com.clust4j.log.Log.Tag.Algo;
+import com.clust4j.log.LogTimer;
+import com.clust4j.metrics.pairwise.Distance;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.pairwise.Pairwise;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * KMedoids is
+ * a clustering algorithm related to the {@link KMeans} algorithm and the
+ * medoidshift algorithm. Both the KMeans and KMedoids algorithms are
+ * partitional (breaking the dataset up into groups) and both attempt
+ * to minimize the distance between points labeled to be in a cluster
+ * and a point designated as the center of that cluster. In contrast to
+ * the KMeans algorithm, KMedoids chooses datapoints as centers (medoids
+ * or exemplars) and works with an arbitrary matrix of distances between
+ * datapoints instead of Euclidean distance (l2 norm). This method was proposed in
+ * 1987 for use with the Manhattan distance (l1 norm) and other distances.
+ *
+ *
+ * clust4j utilizes the
+ * Voronoi iteration technique to identify clusters. Alternative greedy searches,
+ * including PAM (partitioning around medoids), are faster yet may not find the optimal
+ * solution. For this reason, clust4j's implementation of KMedoids almost always produces
+ * better partitions than {@link KMeans}, though it typically takes longer to run as well.
+ *
+ * @author Taylor G Smith <tgsmith61591@gmail.com>
+ * @see {@link AbstractPartitionalClusterer}
+ */
+final public class KMedoids extends AbstractCentroidClusterer {
+
+ /**
+ *
+ */
+ private static final long serialVersionUID = -4468316488158880820L;
+ final public static GeometricallySeparable DEF_DIST = Distance.MANHATTAN;
+ final public static int DEF_MAX_ITER = 10;
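+
+    /*
+     * Illustrative usage sketch (hypothetical data; Manhattan distance is
+     * the default metric, and getLabels() is assumed from the clusterer
+     * superclass):
+     *
+     *   KMedoids model = new KMedoidsParameters(2).fitNewModel(data);
+     *   int[] labels = model.getLabels();
+     */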
+
+ /**
+ * Stores the indices of the current medoids. Each index,
+ * 0 thru k-1, corresponds to the class label for the cluster.
+ */
+ volatile private int[] medoid_indices = new int[k];
+
+ /**
+ * Upper triangular, M x M matrix denoting distances between records.
+ * Is only populated during training phase and then set to null for
+ * garbage collection, as a large-M matrix has a high space footprint: O(N^2).
+ * This is only needed during training and then can safely be collected
+ * to free up heap space.
+ */
+ volatile private double[][] dist_mat = null;
+
+ /**
+ * Map the index to the WSS
+ */
+    volatile private TreeMap<Integer, Double> med_to_wss = new TreeMap<>();
+
+
+ protected KMedoids(final RealMatrix data) {
+ this(data, DEF_K);
+ }
+
+ protected KMedoids(final RealMatrix data, final int k) {
+ this(data, new KMedoidsParameters(k).setMetric(Distance.MANHATTAN));
+ }
+
+ protected KMedoids(final RealMatrix data, final KMedoidsParameters planner) {
+ super(data, planner);
+
+ // Check if is Manhattan
+ if (!this.dist_metric.equals(Distance.MANHATTAN)) {
+ warn("KMedoids is intented to run with Manhattan distance, WSS/BSS computations will be inaccurate");
+ //this.dist_metric = Distance.MANHATTAN; // idk that we want to enforce this...
+ }
+ }
+
+
+ @Override
+ public String getName() {
+ return "KMedoids";
+ }
+
+ @Override
+ protected KMedoids fit() {
+ synchronized (fitLock) {
+
+ if (null != labels) // already fit
+ return this;
+
+ final LogTimer timer = new LogTimer();
+ final double[][] X = data.getData();
+ final double nan = Double.NaN;
+
+
+ // Corner case: K = 1 or all singular
+ if (1 == k) {
+ labelFromSingularK(X);
+ fitSummary.add(new Object[]{iter, converged,
+ tss, // tss
+ tss, // avg per cluster
+ tss, // wss
+ nan, // bss (none)
+ timer.wallTime()});
+ sayBye(timer);
+ return this;
+ }
+
+
+ // We do this in KMedoids and not KMeans, because KMedoids uses
+ // real points as medoids and not means for centroids, thus
+ // the recomputation of distances is unnecessary with the dist mat
+ dist_mat = Pairwise.getDistance(X, getSeparabilityMetric(), true, false);
+ info("distance matrix computed in " + timer.toString());
+
+ // Initialize labels
+ medoid_indices = init_centroid_indices;
+
+
+ ClusterAssignments clusterAssignments;
+ MedoidReassignmentHandler rassn;
+ int[] newMedoids = medoid_indices;
+
+ // Cost vars
+ double bestCost = Double.POSITIVE_INFINITY,
+ maxCost = Double.NEGATIVE_INFINITY,
+ avgCost = Double.NaN, wss_sum = nan;
+
+
+ // Iterate while the cost decreases:
+ boolean convergedFromCost = false; // from cost or system changes?
+ boolean configurationChanged = true;
+ while (configurationChanged
+ && iter < maxIter) {
+
+ /*
+ * 1. In each cluster, make the point that minimizes
+ * the sum of distances within the cluster the medoid
+ */
+ try {
+ clusterAssignments = assignClosestMedoid(newMedoids);
+ } catch (IllegalClusterStateException ouch) {
+ exitOnBadDistanceMetric(X, timer);
+ return this;
+ }
+
+
+ /*
+ * 1.5 The entries are not 100% equal, so we can (re)assign medoids...
+ */
+ try {
+ rassn = new MedoidReassignmentHandler(clusterAssignments);
+ } catch (IllegalClusterStateException ouch) {
+ exitOnBadDistanceMetric(X, timer);
+ return this;
+ }
+
+ /*
+ * 1.75 This happens in the case of bad kernels that cause
+ * infinities to propagate... we can't segment the input
+ * space and need to just return a single cluster.
+ */
+ if (rassn.new_clusters.size() == 1) {
+ this.k = 1;
+ warn("(dis)similarity metric cannot partition space without propagating Infs. Returning one cluster");
+
+ labelFromSingularK(X);
+ fitSummary.add(new Object[]{iter, converged,
+ tss, // tss
+ tss, // avg per cluster
+ tss, // wss
+ nan, // bss (none)
+ timer.wallTime()});
+ sayBye(timer);
+ return this;
+ }
+
+
+ /*
+ * 2. Reassign each point to the cluster defined by the
+ * closest medoid determined in the previous step.
+ */
+ newMedoids = rassn.reassignedMedoidIdcs;
+
+
+ /*
+ * 2.5 Determine whether configuration changed
+ */
+ boolean lastIteration = VecUtils.equalsExactly(newMedoids, medoid_indices);
+
+
+ /*
+ * 3. Update the costs
+ */
+ converged = lastIteration || (convergedFromCost = FastMath.abs(wss_sum - bestCost) < tolerance);
+ double tmp_wss_sum = rassn.new_clusters.total_cst;
+ double tmp_bss = tss - tmp_wss_sum;
+
+ // Check whether greater than max
+ if (tmp_wss_sum > maxCost)
+ maxCost = tmp_wss_sum;
+
+ if (tmp_wss_sum < bestCost) {
+ bestCost = wss_sum = tmp_wss_sum;
+ labels = rassn.new_clusters.assn; // will be medoid idcs until encoded at end
+ med_to_wss = rassn.new_clusters.costs;
+ centroids = rassn.centers;
+ medoid_indices = newMedoids;
+ bss = tmp_bss;
+
+ // get avg cost
+ avgCost = wss_sum / (double) k;
+ }
+
+ if (converged) {
+ reorderLabelsAndCentroids();
+ }
+
+ /*
+ * 3.5 If this is the last one, it'll show the wss and bss
+ */
+ fitSummary.add(new Object[]{iter,
+ converged,
+ tss,
+ avgCost,
+ wss_sum,
+ bss,
+ timer.wallTime()
+ });
+
+
+ iter++;
+ configurationChanged = !converged;
+ }
+
+ if (!converged)
+ warn("algorithm did not converge");
+ else
+ info("algorithm converged due to " +
+ (convergedFromCost ? "cost minimization" : "harmonious state"));
+
+
+ // wrap things up, create summary..
+ sayBye(timer);
+
+ return this;
+ }
+
+ } // End train
+
+
+ /**
+ * Some metrics produce entirely equal dist matrices...
+ */
+ private void exitOnBadDistanceMetric(double[][] X, LogTimer timer) {
+ warn("distance metric (" + dist_metric + ") produced entirely equal distances");
+ labelFromSingularK(X);
+ fitSummary.add(new Object[]{iter, converged, tss, tss, tss, Double.NaN, Double.NaN, timer.wallTime()});
+ sayBye(timer);
+ }
+
+
+ private ClusterAssignments assignClosestMedoid(int[] medoidIdcs) {
+ double minDist;
+ boolean all_tied = true;
+ int nearest, rowIdx, colIdx;
+ final int[] assn = new int[m];
+ final double[] costs = new double[m];
+ for (int i = 0; i < m; i++) {
+ boolean is_a_medoid = false;
+ minDist = Double.POSITIVE_INFINITY;
+
+ /*
+ * The dist_mat is already computed. We just need to traverse
+             * the upper triangular matrix and identify the minimum
+             * distance per record.
+ */
+ nearest = -1;
+ for (int medoid : medoidIdcs) {
+
+ // Corner case: i is a medoid
+ if (i == medoid) {
+ nearest = medoid;
+ minDist = dist_mat[i][i];
+ is_a_medoid = true;
+ break;
+ }
+
+ rowIdx = FastMath.min(i, medoid);
+ colIdx = FastMath.max(i, medoid);
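+                // dist_mat is upper triangular, so always index [min][max]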
+
+ if (dist_mat[rowIdx][colIdx] < minDist) {
+ minDist = dist_mat[rowIdx][colIdx];
+ nearest = medoid;
+ }
+ }
+
+ /*
+ * If all of the distances are equal, we can end up with a -1 idx...
+ */
+ if (-1 == nearest)
+                nearest = medoidIdcs[getSeed().nextInt(k)]; // fall back to a random medoid
+ if (!is_a_medoid)
+ all_tied = false;
+
+
+ assn[i] = nearest;
+ costs[i] = minDist;
+ }
+
+
+ /*
+ * If everything is tied, we need to bail. Shouldn't happen, now
+ * that we explicitly check earlier on... but we can just label from
+ * a singular K at this point.
+ */
+ if (all_tied) {
+ throw new IllegalClusterStateException("entirely "
+ + "stochastic process: all distances are equal");
+ }
+
+ return new ClusterAssignments(assn, costs);
+ }
+
+
+ /**
+ * Handles medoids reassignments and cost minimizations.
+ * In the Voronoi iteration algorithm, after we've identified the new
+ * cluster assignment, for each cluster, we select the medoid which minimized
+ * intra-cluster variance. Theoretically, this could result in a re-org of clusters,
+ * so we use the new medoid indices to create a new {@link ClusterAssignments} object
+ * as the last step. If the cost does not change in the last step, we know we've
+ * reached convergence.
+ *
+ * @author Taylor G Smith
+ */
+ private class MedoidReassignmentHandler {
+ final ClusterAssignments init_clusters;
+        final ArrayList<double[]> centers = new ArrayList<>(k);
+ final int[] reassignedMedoidIdcs = new int[k];
+
+ // Holds the costs of each cluster in order
+ final ClusterAssignments new_clusters;
+
+ /**
+ * Def constructor
+ *
+ * @param assn - new medoid assignments
+ */
+ MedoidReassignmentHandler(ClusterAssignments assn) {
+ this.init_clusters = assn;
+ medoidAssn();
+ this.new_clusters = assignClosestMedoid(reassignedMedoidIdcs);
+ }
+
+ void medoidAssn() {
+            ArrayList<Integer> members;
+
+ int i = 0;
+            for (Map.Entry<Integer, ArrayList<Integer>> pair : init_clusters.entrySet()) {
+ members = pair.getValue();
+
+ double medoidCost, minCost = Double.POSITIVE_INFINITY;
+ int rowIdx, colIdx, bestMedoid = 0; // start at 0, not -1 in case of all ties...
+ for (int a : members) { // check cost if A is the medoid...
+
+ medoidCost = 0.0;
+ for (int b : members) {
+ if (a == b)
+ continue;
+
+ rowIdx = FastMath.min(a, b);
+ colIdx = FastMath.max(a, b);
+
+ medoidCost += dist_mat[rowIdx][colIdx];
+ }
+
+ if (medoidCost < minCost) {
+ minCost = medoidCost;
+ bestMedoid = a;
+ }
+ }
+
+ this.reassignedMedoidIdcs[i] = bestMedoid;
+ this.centers.add(data.getRow(bestMedoid));
+ i++;
+ }
+ }
+ }
+
+ /**
+ * Simple container for handling cluster assignments. Given
+ * an array of length m of medoid assignments, and an array of length m
+ * of distances to the medoid, organize the new clusters and compute the total
+ * cost of the new system.
+ *
+ * @author Taylor G Smith
+ */
+    private class ClusterAssignments extends TreeMap<Integer, ArrayList<Integer>> {
+ private static final long serialVersionUID = -7488380079772496168L;
+ final int[] assn;
+        TreeMap<Integer, Double> costs; // maps medoid idx to cluster cost
+ double total_cst;
+
+ ClusterAssignments(int[] assn, double[] costs) {
+ super();
+
+ // should be equal in length to costs arg
+ this.assn = assn;
+ this.costs = new TreeMap<>();
+
+ int medoid;
+ double cost;
+            ArrayList<Integer> ref;
+ for (int i = 0; i < assn.length; i++) {
+ medoid = assn[i];
+ cost = costs[i];
+
+ ref = get(medoid); // helps avoid double lookup later
+ if (null == ref) { // not here.
+                    ref = new ArrayList<>();
+ ref.add(i);
+ put(medoid, ref);
+ this.costs.put(medoid, cost);
+ } else {
+ ref.add(i);
+ double d = this.costs.get(medoid);
+ this.costs.put(medoid, d + cost);
+ }
+
+ total_cst += cost;
+ }
+ }
+ }
+
+
+ @Override
+ public Algo getLoggerTag() {
+ return Algo.KMEDOIDS;
+ }
+
+ @Override
+ protected Object[] getModelFitSummaryHeaders() {
+ return new Object[]{
+ "Iter. #", "Converged", "TSS", "Avg Clust. Cost", "Min WSS", "Max BSS", "Wall"
+ };
+ }
+
+ /**
+ * Reorder the labels in order of appearance using the
+ * {@link LabelEncoder}. Also reorder the centroids to correspond
+ * with new label order
+ */
+ protected void reorderLabelsAndCentroids() {
+
+ /*
+ * reorder labels...
+ */
+ final LabelEncoder encoder = new LabelEncoder(labels).fit();
+ labels = encoder.getEncodedLabels();
+
+ int i = 0;
+ centroids = new ArrayList<>();
+ int[] classes = encoder.getClasses();
+ for (int claz : classes) {
+ centroids.add(data.getRow(claz)); // an index, not a counter 0 thru k
+ wss[i++] = med_to_wss.get(claz);
+ }
+ }
+
+ @Override
+ final protected GeometricallySeparable defMetric() {
+ return KMedoids.DEF_DIST;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/KMedoidsParameters.java b/src/main/java/com/clust4j/algo/KMedoidsParameters.java
new file mode 100644
index 00000000..745856e6
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/KMedoidsParameters.java
@@ -0,0 +1,107 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.algo.AbstractCentroidClusterer.InitializationStrategy;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.Random;
+
+public class KMedoidsParameters extends CentroidClustererParameters<KMedoids> {
+ private static final long serialVersionUID = -3288579217568576647L;
+
+ private InitializationStrategy strat = KMedoids.DEF_INIT;
+ private int maxIter = KMedoids.DEF_MAX_ITER;
+
+ public KMedoidsParameters() {
+ this.metric = KMedoids.DEF_DIST;
+ }
+
+ public KMedoidsParameters(int k) {
+ this();
+ this.k = k;
+ }
+
+ @Override
+ public KMedoids fitNewModel(final RealMatrix data) {
+ return new KMedoids(data, this.copy()).fit();
+ }
+
+ @Override
+ public KMedoidsParameters copy() {
+ return new KMedoidsParameters(k)
+ .setMaxIter(maxIter)
+ .setConvergenceCriteria(minChange)
+ .setMetric(metric)
+ .setVerbose(verbose)
+ .setSeed(seed)
+ .setInitializationStrategy(strat)
+ .setForceParallel(parallel);
+ }
+
+ @Override
+ public InitializationStrategy getInitializationStrategy() {
+ return strat;
+ }
+
+ @Override
+ public int getMaxIter() {
+ return maxIter;
+ }
+
+ @Override
+ public KMedoidsParameters setForceParallel(boolean b) {
+ this.parallel = b;
+ return this;
+ }
+
+ @Override
+ public KMedoidsParameters setMetric(final GeometricallySeparable dist) {
+        this.metric = dist; // overriding the default Manhattan metric is discouraged for KMedoids
+ return this;
+ }
+
+ public KMedoidsParameters setMaxIter(final int max) {
+ this.maxIter = max;
+ return this;
+ }
+
+ @Override
+ public KMedoidsParameters setConvergenceCriteria(final double min) {
+ this.minChange = min;
+ return this;
+ }
+
+ @Override
+ public KMedoidsParameters setInitializationStrategy(InitializationStrategy init) {
+ this.strat = init;
+ return this;
+ }
+
+ @Override
+ public KMedoidsParameters setSeed(final Random seed) {
+ this.seed = seed;
+ return this;
+ }
+
+ @Override
+ public KMedoidsParameters setVerbose(final boolean v) {
+ this.verbose = v;
+ return this;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/LabelEncoder.java b/src/main/java/com/clust4j/algo/LabelEncoder.java
new file mode 100644
index 00000000..b6fc0671
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/LabelEncoder.java
@@ -0,0 +1,184 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.except.ModelNotFitException;
+import com.clust4j.utils.VecUtils;
+
+import java.util.LinkedHashSet;
+import java.util.TreeMap;
+
+public class LabelEncoder extends BaseModel implements java.io.Serializable {
+ private static final long serialVersionUID = 6618077714920820376L;
+
+ final int[] rawLabels;
+ final int numClasses, n;
+ final int[] classes;
+
+    private volatile TreeMap<Integer, Integer> encodedMapping = null;
+    private volatile TreeMap<Integer, Integer> reverseMapping = null;
+ private volatile int[] encodedLabels = null;
+ private volatile boolean fit = false;
+
+
+ public LabelEncoder(int[] labels) {
+ VecUtils.checkDims(labels);
+
+        final LinkedHashSet<Integer> unique = VecUtils.unique(labels);
+ numClasses = unique.size();
+ if (numClasses < 2 && !allowSingleClass()) {
+ throw new IllegalArgumentException("y has " + numClasses + " unique class"
+ + (numClasses != 1 ? "es" : "") + " and requires at least two");
+ }
+
+ this.rawLabels = VecUtils.copy(labels);
+ this.n = rawLabels.length;
+
+ int idx = 0;
+ this.classes = new int[numClasses];
+ for (Integer u : unique) classes[idx++] = u.intValue();
+
+ // Initialize mappings
+ encodedMapping = new TreeMap<>();
+ reverseMapping = new TreeMap<>();
+ encodedLabels = new int[n];
+ }
+
+
+ /**
+ * For subclasses that need to have built-in mappings,
+ * this hook should be called in the constructor
+ *
+ * @param key
+     * @param value
+ */
+ protected void addMapping(Integer key, Integer value) {
+ encodedMapping.put(key, value);
+ reverseMapping.put(value, key);
+ }
+
+ /**
+ * Whether or not to allow only a single class mapping
+ *
+ * @return true if allow single class mappings
+ */
+ protected boolean allowSingleClass() {
+ return false;
+ }
+
+ @Override
+ public LabelEncoder fit() {
+ synchronized (fitLock) {
+ if (fit)
+ return this;
+
+ int nextLabel = 0, label;
+ Integer val;
+ for (int i = 0; i < n; i++) {
+ label = rawLabels[i];
+ val = encodedMapping.get(label);
+
+ if (null == val) { // not yet seen
+ val = nextLabel++;
+ encodedMapping.put(label, val);
+ reverseMapping.put(val, label);
+ }
+
+ encodedLabels[i] = val;
+ }
+
+
+ fit = true;
+ return this;
+ }
+ }
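+
+    /*
+     * Behavior sketch (hypothetical values): classes are re-coded in order
+     * of first appearance.
+     *
+     *   int[] raw = {40, 40, 7, 40, 7, 99};
+     *   int[] enc = new LabelEncoder(raw).fit().getEncodedLabels();
+     *   // enc == {0, 0, 1, 0, 1, 2}
+     */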
+
+ public Integer encodeOrNull(int label) {
+ if (!fit) throw new ModelNotFitException("model not yet fit");
+ return encodedMapping.get(label);
+ }
+
+ public int[] getClasses() {
+ return VecUtils.copy(classes);
+ }
+
+ public int[] getEncodedLabels() {
+ if (!fit) throw new ModelNotFitException("model not yet fit");
+ return VecUtils.copy(encodedLabels);
+ }
+
+ public int getNumClasses() {
+ return numClasses;
+ }
+
+ public int[] getRawLabels() {
+ return VecUtils.copy(rawLabels);
+ }
+
+ public Integer reverseEncodeOrNull(int encodedLabel) {
+ if (!fit) throw new ModelNotFitException("model not yet fit");
+ return reverseMapping.get(encodedLabel);
+ }
+
+ /**
+ * Return an encoded label array back to its original state
+ *
+     * @return the decoded (raw) label array
+ * @throws IllegalArgumentException if value not in mappings
+ */
+ public int[] reverseTransform(int[] encodedLabels) {
+ if (!fit) throw new ModelNotFitException("model not yet fit");
+ final int[] out = new int[encodedLabels.length];
+
+ int val;
+ Integer encoding;
+ for (int i = 0; i < out.length; i++) {
+ val = encodedLabels[i];
+ encoding = reverseMapping.get(val);
+
+ if (null == encoding)
+                throw new IllegalArgumentException(val + " does not exist in label mappings");
+ out[i] = encoding;
+ }
+
+ return out;
+ }
+
+ /**
+ * Encode a new label array based on the fitted mappings
+ *
+ * @param newLabels
+     * @return the encoded label array
+ * @throws IllegalArgumentException if value not in mappings
+ */
+ public int[] transform(int[] newLabels) {
+ if (!fit) throw new ModelNotFitException("model not yet fit");
+ final int[] out = new int[newLabels.length];
+
+ int val;
+ Integer encoding;
+ for (int i = 0; i < out.length; i++) {
+ val = newLabels[i];
+ encoding = encodedMapping.get(val);
+
+ if (null == encoding)
+                throw new IllegalArgumentException(val + " does not exist in label mappings");
+ out[i] = encoding;
+ }
+
+ return out;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/MeanShift.java b/src/main/java/com/clust4j/algo/MeanShift.java
new file mode 100644
index 00000000..7d74e26a
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/MeanShift.java
@@ -0,0 +1,1137 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.except.IllegalClusterStateException;
+import com.clust4j.except.ModelNotFitException;
+import com.clust4j.kernel.GaussianKernel;
+import com.clust4j.kernel.RadialBasisKernel;
+import com.clust4j.log.Log.Tag.Algo;
+import com.clust4j.log.LogTimer;
+import com.clust4j.log.Loggable;
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.metrics.pairwise.SimilarityMetric;
+import com.clust4j.utils.EntryPair;
+import com.clust4j.utils.MatUtils;
+import com.clust4j.utils.VecUtils;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.util.FastMath;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Random;
+import java.util.TreeSet;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.concurrent.RejectedExecutionException;
+
+/**
+ * Mean shift is a procedure for locating the maxima of a density function given discrete
+ * data sampled from that function. It is useful for detecting the modes of this density.
+ * This is an iterative method, and we start with an initial estimate x. Let a
+ * {@link RadialBasisKernel} function be given. This function determines the weight of nearby
+ * points for re-estimation of the mean. Typically a {@link GaussianKernel} kernel on the
+ * distance to the current estimate is used.
+ *
+ * @author Taylor G Smith <tgsmith61591@gmail.com>, adapted from sklearn implementation
+ * @see Mean shift on Wikipedia
+ */
+final public class MeanShift
+ extends AbstractDensityClusterer
+ implements CentroidLearner, Convergeable, NoiseyClusterer {
+ /**
+ *
+ */
+ private static final long serialVersionUID = 4423672142693334046L;
+
+ final public static double DEF_BANDWIDTH = 5.0;
+ final public static int DEF_MAX_ITER = 300;
+ final public static int DEF_MIN_BIN_FREQ = 1;
+ final static double incrementAmt = 0.25;
+    final public static HashSet<Class<? extends GeometricallySeparable>> UNSUPPORTED_METRICS;
+
+
+ /**
+ * Static initializer
+ */
+ static {
+ UNSUPPORTED_METRICS = new HashSet<>();
+ // Add metrics here if necessary... already vetoes any
+ // similarity metrics, so this might be sufficient...
+ }
+
+ @Override
+ final public boolean isValidMetric(GeometricallySeparable geo) {
+ return !UNSUPPORTED_METRICS.contains(geo.getClass()) && !(geo instanceof SimilarityMetric);
+ }
+
+
+ /**
+ * The max iterations
+ */
+ private final int maxIter;
+
+ /**
+ * Min change convergence criteria
+ */
+ private final double tolerance;
+
+ /**
+ * The kernel bandwidth (volatile because can change in sync method)
+ */
+ volatile private double bandwidth;
+
+ /**
+ * Class labels
+ */
+ volatile private int[] labels = null;
+
+ /**
+ * The M x N seeds to be used as initial kernel points
+ */
+ private double[][] seeds;
+
+ /**
+     * Num cols
+ */
+ private final int n;
+
+ /**
+ * Whether bandwidth is auto-estimated
+ */
+ private final boolean autoEstimate;
+
+
+ /**
+ * Track convergence
+ */
+ private volatile boolean converged = false;
+ /**
+ * The centroid records
+ */
+    private volatile ArrayList<double[]> centroids;
+ private volatile int numClusters;
+ private volatile int numNoisey;
+ /**
+ * Count iterations
+ */
+ private volatile int itersElapsed = 0;
+
+
+ /**
+ * Default constructor
+ *
+ * @param data
+ * @param bandwidth
+ */
+ protected MeanShift(RealMatrix data, final double bandwidth) {
+ this(data, new MeanShiftParameters(bandwidth));
+ }
+
+ /**
+ * Default constructor for auto bandwidth estimation
+ *
+ * @param data
+ */
+ protected MeanShift(RealMatrix data) {
+ this(data, new MeanShiftParameters());
+ }
+
+ /**
+ * Constructor with custom MeanShiftPlanner
+ *
+ * @param data
+ * @param planner
+ */
+ protected MeanShift(RealMatrix data, MeanShiftParameters planner) {
+ super(data, planner);
+
+
+ // Check bandwidth...
+ if (planner.getBandwidth() <= 0.0)
+ error(new IllegalArgumentException("bandwidth "
+ + "must be greater than 0.0"));
+
+
+ // Check seeds dimension
+ if (null != planner.getSeeds()) {
+ if (planner.getSeeds().length == 0)
+ error(new IllegalArgumentException("seeds "
+ + "length must be greater than 0"));
+
+ // Throws NonUniformMatrixException if non uniform...
+ MatUtils.checkDimsForUniformity(planner.getSeeds());
+
+ if (planner.getSeeds()[0].length != (n = this.data.getColumnDimension()))
+ error(new DimensionMismatchException(planner.getSeeds()[0].length, n));
+
+ if (planner.getSeeds().length > this.data.getRowDimension())
+ error(new IllegalArgumentException("seeds "
+ + "length cannot exceed number of datapoints"));
+
+ info("initializing kernels from given seeds");
+
+ // Handle the copying in the planner
+ seeds = planner.getSeeds();
+        } else { // Default = all
+ info("no seeds provided; defaulting to all datapoints");
+ seeds = this.data.getData(); // use THIS as it's already scaled...
+ n = this.data.getColumnDimension();
+ }
+
+ /*
+ * Check metric for validity
+ */
+ if (!isValidMetric(this.dist_metric)) {
+ warn(this.dist_metric.getName() + " is not valid for " + getName() + ". "
+ + "Falling back to default Euclidean dist");
+ setSeparabilityMetric(DEF_DIST);
+ }
+
+
+ this.maxIter = planner.getMaxIter();
+ this.tolerance = planner.getConvergenceTolerance();
+
+
+ this.autoEstimate = planner.getAutoEstimate();
+ final LogTimer aeTimer = new LogTimer();
+
+
+ /*
+ * Assign bandwidth
+ */
+ this.bandwidth =
+ /* if all singular, just pick a number... */
+ this.singular_value ? 0.5 :
+ /* Otherwise if we're auto-estimating, estimate it */
+ autoEstimate ?
+ autoEstimateBW(this, planner.getAutoEstimationQuantile()) :
+ planner.getBandwidth();
+
+ /*
+ * Give auto-estimation timer update
+ */
+ if (autoEstimate && !this.singular_value) info("bandwidth auto-estimated in " +
+ (parallel ? "parallel in " : "") + aeTimer.toString());
+
+
+ logModelSummary();
+ }
+
+ @Override
+ final protected ModelSummary modelSummary() {
+ return new ModelSummary(new Object[]{
+ "Num Rows", "Num Cols", "Metric", "Bandwidth", "Allow Par.", "Max Iter.", "Tolerance"
+ }, new Object[]{
+ data.getRowDimension(), data.getColumnDimension(),
+ getSeparabilityMetric(),
+ (autoEstimate ? "(auto) " : "") + bandwidth,
+ parallel,
+ maxIter, tolerance
+ });
+ }
+
+ /**
+ * For testing...
+ *
+ * @param data
+ * @param quantile
+ * @param sep
+ * @param seed
+ * @param parallel
+ * @return
+ */
+ final protected static double autoEstimateBW(Array2DRowRealMatrix data,
+ double quantile, GeometricallySeparable sep, Random seed, boolean parallel) {
+
+ return autoEstimateBW(new NearestNeighbors(data,
+ new NearestNeighborsParameters((int) (data.getRowDimension() * quantile))
+ .setSeed(seed)
+ .setForceParallel(parallel)).fit(),
+ data.getDataRef(),
+ quantile,
+ sep, seed,
+ parallel,
+ null);
+ }
+
+ /**
+ * Actually called internally
+ *
+ * @param caller
+ * @param quantile
+ * @return
+ */
+ final protected static double autoEstimateBW(MeanShift caller, double quantile) {
+ LogTimer timer = new LogTimer();
+ NearestNeighbors nn = new NearestNeighbors(caller,
+ new NearestNeighborsParameters((int) (caller.data.getRowDimension() * quantile))
+ .setForceParallel(caller.parallel)).fit();
+ caller.info("fit nearest neighbors model for auto-bandwidth automation in " + timer.toString());
+
+ return autoEstimateBW(nn,
+ caller.data.getDataRef(), quantile, caller.getSeparabilityMetric(),
+ caller.getSeed(), caller.parallel, caller);
+ }
+
+ final protected static double autoEstimateBW(NearestNeighbors nn, double[][] data,
+ double quantile, GeometricallySeparable sep, Random seed, boolean parallel,
+ Loggable logger) {
+
+ if (quantile <= 0 || quantile > 1)
+ throw new IllegalArgumentException("illegal quantile");
+ final int m = data.length;
+
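+        // The estimate computed below is the mean, over all m records, of each
+        // record's distance to its k-th nearest neighbor, where k = (int)(quantile * m)
+        // was fixed when the NearestNeighbors model was fit.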
+ double bw = 0.0;
+ final double[][] X = nn.data.getDataRef();
+ final int minsize = ParallelChunkingTask.ChunkingStrategy.DEF_CHUNK_SIZE;
+ final int chunkSize = X.length < minsize ? minsize : X.length / 5;
+ final int numChunks = ParallelChunkingTask.ChunkingStrategy.getNumChunks(chunkSize, m);
+ Neighborhood neighb;
+
+
+ if (!parallel) {
+ /*
+ * For each chunk of 500, get the neighbors and then compute the
+ * sum of the row maxes of the distance matrix.
+ */
+ int chunkStart, nextChunk;
+ for (int chunk = 0; chunk < numChunks; chunk++) {
+ chunkStart = chunk * chunkSize;
+ nextChunk = chunk == numChunks - 1 ? m : chunkStart + chunkSize;
+
+ double[][] nextMatrix = new double[nextChunk - chunkStart][];
+ for (int i = chunkStart, j = 0; i < nextChunk; i++, j++)
+ nextMatrix[j] = X[i];
+
+ neighb = nn.getNeighbors(nextMatrix);
+ for (double[] distRow : neighb.getDistances()) {
+ //bw += VecUtils.max(distRow);
+ bw += distRow[distRow.length - 1]; // it's sorted!
+ }
+ }
+ } else {
+ // Estimate bandwidth in parallel
+ bw = ParallelBandwidthEstimator.doAll(X, nn);
+ }
+
+ return bw / (double) m;
+ }
+
+
+ /**
+ * Estimates the bandwidth of the model in parallel for scalability
+ *
+ * @author Taylor G Smith
+ */
+ static class ParallelBandwidthEstimator
+        extends ParallelChunkingTask<Double>
+ implements java.io.Serializable {
+
+ private static final long serialVersionUID = 1171269106158790138L;
+ final NearestNeighbors nn;
+ final int high;
+ final int low;
+
+ ParallelBandwidthEstimator(double[][] X, NearestNeighbors nn) {
+
+ // Use the SimpleChunker
+ super(X);
+
+ this.nn = nn;
+ this.low = 0;
+ this.high = strategy.getNumChunks(X);
+ }
+
+ ParallelBandwidthEstimator(ParallelBandwidthEstimator task, int low, int high) {
+ super(task);
+
+ this.nn = task.nn;
+ this.low = low;
+ this.high = high;
+ }
+
+ @Override
+ protected Double compute() {
+ if (high - low <= 1) { // generally should equal one...
+ return reduce(chunks.get(low));
+ } else {
+ int mid = this.low + (this.high - this.low) / 2;
+ ParallelBandwidthEstimator left = new ParallelBandwidthEstimator(this, low, mid);
+ ParallelBandwidthEstimator right = new ParallelBandwidthEstimator(this, mid, high);
+
+ left.fork();
+ Double l = right.compute();
+ Double r = left.join();
+
+ return l + r;
+ }
+ }
+
+ @Override
+ public Double reduce(Chunk chunk) {
+ double bw = 0.0;
+ Neighborhood neighb = nn.getNeighbors(chunk.get(), false);
+
+ for (double[] distRow : neighb.getDistances()) {
+ //bw += VecUtils.max(distRow);
+ bw += distRow[distRow.length - 1]; // it's sorted!
+ }
+
+ return bw;
+ }
+
+ static double doAll(double[][] X, NearestNeighbors nn) {
+ return getThreadPool().invoke(new ParallelBandwidthEstimator(X, nn));
+ }
+ }
+
+
+ /**
+ * Handles the output for the {@link #singleSeed(double[], RadiusNeighbors, double[][], int)}
+ * method. Implements comparable to be sorted by the value in the entry pair.
+ *
+ * @author Taylor G Smith
+ */
+    protected static class MeanShiftSeed implements Comparable<MeanShiftSeed> {
+ final double[] dists;
+ /**
+ * The number of points in the bandwidth
+ */
+ final Integer count;
+ final int iterations;
+
+ MeanShiftSeed(final double[] dists, final int count, int iterations) {
+ this.dists = dists;
+ this.count = count;
+ this.iterations = iterations;
+ }
+
+ /*
+ * we don't need these methods in the actual algo, and they just
+ * create more need for testing to get good coverage, so we can
+ * just omit them
+ *
+ @Override
+ public boolean equals(Object o) {
+ if(this == o)
+ return true;
+ if(o instanceof MeanShiftSeed) {
+ MeanShiftSeed m = (MeanShiftSeed)o;
+ return VecUtils.equalsExactly(dists, m.dists)
+ && count.intValue() == m.count.intValue();
+ }
+
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return "{" + Arrays.toString(dists) + " : " + count + "}";
+ }
+
+ @Override
+ public int hashCode() {
+ int h = 31;
+ for(double d: dists)
+ h ^= (int)d;
+ return h ^ count;
+ }
+ */
+
+        EntryPair<double[], Integer> getPair() {
+ return new EntryPair<>(dists, count);
+ }
+
+ @Override
+ public int compareTo(MeanShiftSeed o2) {
+ int comp = count.compareTo(o2.count);
+
+ if (comp == 0) {
+ final double[] d2 = o2.dists;
+
+ for (int i = 0; i < dists.length; i++) {
+ int c = Double.valueOf(dists[i]).compareTo(d2[i]);
+ if (c != 0)
+ return -c;
+ }
+ }
+
+ return -comp;
+ }
+ }
+
+
+ /**
+ * Light struct to hold summary info
+ *
+ * @author Taylor G Smith
+ */
+ static class SummaryLite {
+ final String name;
+ final int iters;
+ final String fmtTime;
+ final String wallTime;
+ boolean retained = false;
+
+ SummaryLite(final String nm, final int iter,
+ final String fmt, final String wall) {
+ this.name = nm;
+ this.iters = iter;
+ this.fmtTime = fmt;
+ this.wallTime = wall;
+ }
+
+ Object[] toArray() {
+ return new Object[]{
+ name,
+ iters,
+ fmtTime,
+ wallTime,
+ retained
+ };
+ }
+ }
+
+ /**
+ * The superclass for parallelized MeanShift tasks
+ *
+     * @param <T> the result type produced by the task
+ * @author Taylor G Smith
+ */
+    abstract static class ParallelMSTask<T> extends ParallelChunkingTask<T> {
+ private static final long serialVersionUID = 2139716909891672022L;
+        final ConcurrentLinkedDeque<SummaryLite> summaries;
+ final double[][] X;
+
+        ParallelMSTask(double[][] X, ConcurrentLinkedDeque<SummaryLite> summaries) {
+ super(X);
+ this.summaries = summaries;
+ this.X = X;
+ }
+
+        ParallelMSTask(ParallelMSTask<T> task) {
+ super(task);
+ this.summaries = task.summaries;
+ this.X = task.X;
+ }
+
+ public String formatName(String str) {
+ StringBuilder sb = new StringBuilder();
+ boolean hyphen = false; // have we hit the hyphen yet?
+ boolean started_worker = false;
+ boolean seen_k = false;
+ boolean finished_worker = false;
+
+ for (char c : str.toCharArray()) {
+ if (hyphen || Character.isUpperCase(c)) {
+ if (started_worker && !finished_worker) {
+ if (c == 'k') { // past first 'r'...
+ seen_k = true;
+ continue;
+ }
+
+ // in the middle of the word "worker"
+ if (c != 'r')
+ continue;
+ else if (!seen_k)
+ continue;
+
+ // At the last char in 'worker'
+ finished_worker = true;
+ sb.append("Kernel");
+ } else if (!started_worker && c == 'w') {
+ started_worker = true;
+ } else {
+ sb.append(c);
+ }
+ } else if ('-' == c) {
+ hyphen = true;
+ sb.append(c);
+ }
+ }
+
+ return sb.toString();
+ }
+ }
+
+ /**
+ * Class that handles construction of the center intensity object
+ *
+ * @author Taylor G Smith
+ */
+    static abstract class CenterIntensity implements java.io.Serializable, Iterable<MeanShiftSeed> {
+ private static final long serialVersionUID = -6535787295158719610L;
+
+ abstract int getIters();
+
+ abstract boolean isEmpty();
+
+        abstract ArrayList<SummaryLite> getSummaries();
+
+ abstract int size();
+ }
+
+ /**
+ * A class that utilizes a {@link java.util.concurrent.ForkJoinPool}
+ * as parallel executors to run many tasks across multiple cores.
+ *
+ * @author Taylor G Smith
+ */
+ static class ParallelSeedExecutor
+        extends ParallelMSTask<ConcurrentSkipListSet<MeanShiftSeed>> {
+
+ private static final long serialVersionUID = 632871644265502894L;
+
+ final int maxIter;
+ final RadiusNeighbors nbrs;
+
+        final ConcurrentSkipListSet<MeanShiftSeed> computedSeeds;
+ final int high, low;
+
+
+ ParallelSeedExecutor(
+ int maxIter, double[][] X, RadiusNeighbors nbrs,
+            ConcurrentLinkedDeque<SummaryLite> summaries) {
+
+            /*
+ * Pass summaries reference to super
+ */
+ super(X, summaries);
+
+ this.maxIter = maxIter;
+ this.nbrs = nbrs;
+ this.computedSeeds = new ConcurrentSkipListSet<>();
+ this.low = 0;
+ this.high = strategy.getNumChunks(X);
+ }
+
+ ParallelSeedExecutor(ParallelSeedExecutor task, int low, int high) {
+ super(task);
+
+ this.maxIter = task.maxIter;
+ this.nbrs = task.nbrs;
+ this.computedSeeds = task.computedSeeds;
+ this.high = high;
+ this.low = low;
+ }
+
+ @Override
+        protected ConcurrentSkipListSet<MeanShiftSeed> compute() {
+ if (high - low <= 1) { // generally should equal one...
+ return reduce(chunks.get(low));
+
+ } else {
+ int mid = this.low + (this.high - this.low) / 2;
+ ParallelSeedExecutor left = new ParallelSeedExecutor(this, low, mid);
+ ParallelSeedExecutor right = new ParallelSeedExecutor(this, mid, high);
+
+ left.fork();
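+                // classic ForkJoin binary split: fork the left half asynchronously,
+                // compute the right half in this thread, then join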
+ right.compute();
+ left.join();
+
+ return computedSeeds;
+ }
+ }
+
+ @Override
+        public ConcurrentSkipListSet<MeanShiftSeed> reduce(Chunk chunk) {
+ for (double[] seed : chunk.get()) {
+ MeanShiftSeed ms = singleSeed(seed, nbrs, X, maxIter);
+ if (null == ms)
+ continue;
+
+ computedSeeds.add(ms);
+ String nm = getName();
+ summaries.add(new SummaryLite(
+ nm,
+ ms.iterations,
+ timer.formatTime(),
+ timer.wallTime()
+ ));
+ }
+
+ return computedSeeds;
+ }
+
+        static ConcurrentSkipListSet<MeanShiftSeed> doAll(
+ int maxIter, double[][] X, RadiusNeighbors nbrs,
+            ConcurrentLinkedDeque<SummaryLite> summaries) {
+
+ return getThreadPool().invoke(
+ new ParallelSeedExecutor(
+ maxIter, X, nbrs,
+ summaries));
+ }
+ }
+
+ class ParallelCenterIntensity extends CenterIntensity {
+ private static final long serialVersionUID = 4392163493242956320L;
+
+        final ConcurrentSkipListSet<Integer> itrz = new ConcurrentSkipListSet<>();
+        final ConcurrentSkipListSet<MeanShiftSeed> computedSeeds;
+
+ /**
+ * Serves as a reference for passing to parallel job
+ */
+        final ConcurrentLinkedDeque<SummaryLite> summaries = new ConcurrentLinkedDeque<>();
+
+ final LogTimer timer;
+ final RadiusNeighbors nbrs;
+
+ ParallelCenterIntensity(RadiusNeighbors nbrs) {
+
+ this.nbrs = nbrs;
+ this.timer = new LogTimer();
+
+ // Execute forkjoinpool
+ this.computedSeeds = ParallelSeedExecutor.doAll(maxIter, seeds, nbrs, summaries);
+ for (MeanShiftSeed sd : computedSeeds)
+ itrz.add(sd.iterations);
+ }
+
+ @Override
+ public int getIters() {
+ return itrz.last();
+ }
+
+ @Override
+        public ArrayList<SummaryLite> getSummaries() {
+ return new ArrayList<>(summaries);
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return computedSeeds.isEmpty();
+ }
+
+ @Override
+        public Iterator<MeanShiftSeed> iterator() {
+ return computedSeeds.iterator();
+ }
+
+ @Override
+ public int size() {
+ return computedSeeds.size();
+ }
+ }
+
+ /**
+ * Compute the center intensity entry pairs serially and call the
+ * {@link MeanShift#singleSeed(double[], RadiusNeighbors, double[][], int)} method
+ *
+ * @author Taylor G Smith
+ */
+ class SerialCenterIntensity extends CenterIntensity {
+ private static final long serialVersionUID = -1117327079708746405L;
+
+ int itrz = 0;
+        final TreeSet<MeanShiftSeed> computedSeeds;
+        final ArrayList<SummaryLite> summaries = new ArrayList<>();
+
+ SerialCenterIntensity(RadiusNeighbors nbrs) {
+
+ LogTimer timer;
+
+ // Now get single seed members
+ MeanShiftSeed sd;
+ this.computedSeeds = new TreeSet<>();
+ final double[][] X = data.getData();
+
+ int idx = 0;
+ for (double[] seed : seeds) {
+ idx++;
+ timer = new LogTimer();
+ sd = singleSeed(seed, nbrs, X, maxIter);
+
+ if (null == sd)
+ continue;
+
+ computedSeeds.add(sd);
+ itrz = FastMath.max(itrz, sd.iterations);
+
+ // If it actually converged, add the summary
+ summaries.add(new SummaryLite(
+ "Kernel " + (idx - 1), sd.iterations,
+ timer.formatTime(), timer.wallTime()
+ ));
+ }
+ }
+
+ @Override
+ public int getIters() {
+ return itrz;
+ }
+
+ @Override
+        public ArrayList<SummaryLite> getSummaries() {
+ return summaries;
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return computedSeeds.isEmpty();
+ }
+
+ @Override
+        public Iterator<MeanShiftSeed> iterator() {
+ return computedSeeds.iterator();
+ }
+
+ @Override
+ public int size() {
+ return computedSeeds.size();
+ }
+ }
+
+
+ /**
+ * Get the kernel bandwidth
+ *
+ * @return kernel bandwidth
+ */
+ public double getBandwidth() {
+ return bandwidth;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public boolean didConverge() {
+ return converged;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int itersElapsed() {
+ return itersElapsed;
+ }
+
+ /**
+ * Returns a copy of the seeds matrix
+ *
+ * @return
+ */
+ public double[][] getKernelSeeds() {
+ return MatUtils.copy(seeds);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int getMaxIter() {
+ return maxIter;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double getConvergenceTolerance() {
+ return tolerance;
+ }
+
+ @Override
+ public String getName() {
+ return "MeanShift";
+ }
+
+
+ @Override
+ public Algo getLoggerTag() {
+ return com.clust4j.log.Log.Tag.Algo.MEANSHIFT;
+ }
+
+
+ @Override
+ protected MeanShift fit() {
+ synchronized (fitLock) {
+
+ if (null != labels) // Already fit this model
+ return this;
+
+
+ // Put the results into a Map (hash because tree imposes comparable casting)
+ final LogTimer timer = new LogTimer();
+            centroids = new ArrayList<>();
+
+
+ /*
+ * Get the neighborhoods and center intensity object. Will iterate until
+ * either the centers are found, or the max try count is exceeded. For each
+ * iteration, will increase bandwidth.
+ */
+ RadiusNeighbors nbrs = new RadiusNeighbors(
+ this, bandwidth).fit();
+
+
+ // Compute the seeds and center intensity
+ // If parallelism is permitted, try it.
+ CenterIntensity intensity = null;
+ if (parallel) {
+ try {
+ intensity = new ParallelCenterIntensity(nbrs);
+ } catch (RejectedExecutionException e) {
+ // Shouldn't happen...
+ warn("parallel search failed; falling back to serial");
+ }
+ }
+
+ // Gets here if serial or if parallel failed...
+ if (null == intensity)
+ intensity = new SerialCenterIntensity(nbrs);
+
+
+ // Check for points all too far from seeds
+ if (intensity.isEmpty()) {
+ error(new IllegalClusterStateException("No point "
+ + "was within bandwidth=" + bandwidth
+ + " of any seed; try increasing bandwidth"));
+ } else {
+ converged = true;
+ itersElapsed = intensity.getIters(); // max iters elapsed
+ }
+
+
+ // Extract the centroids
+ int idx = 0, m_prime = intensity.size();
+ final Array2DRowRealMatrix sorted_centers = new Array2DRowRealMatrix(m_prime, n);
+
+ for (MeanShiftSeed entry : intensity)
+ sorted_centers.setRow(idx++, entry.getPair().getKey());
+
+ // Fit the new neighbors model
+ nbrs = new RadiusNeighbors(sorted_centers,
+ new RadiusNeighborsParameters(bandwidth)
+ .setSeed(this.random_state)
+ .setMetric(this.dist_metric)
+ .setForceParallel(parallel), true).fit();
+
+
+ // Post-processing. Remove near duplicate seeds
+ // If dist btwn two kernels is less than bandwidth, remove one w fewer pts
+ // Create a boolean mask, init true
+ final boolean[] unique = new boolean[m_prime];
+ for (int i = 0; i < unique.length; i++) unique[i] = true;
+
+
+ // Pre-filtered summaries...
+            ArrayList<SummaryLite> allSummary = intensity.getSummaries();
+
+
+ // Iterate over sorted centers and query radii
+ int redundant_ct = 0;
+ int[] indcs;
+ double[] center;
+ for (int i = 0; i < m_prime; i++) {
+ if (unique[i]) {
+ center = sorted_centers.getRow(i);
+ indcs = nbrs.getNeighbors(
+ new double[][]{center},
+ bandwidth, false)
+ .getIndices()[0];
+
+ for (int id : indcs)
+ unique[id] = false;
+
+ unique[i] = true; // Keep this as true
+ }
+ }
+
+
+ // Now assign the centroids...
+ SummaryLite summ;
+ for (int i = 0; i < unique.length; i++) {
+ summ = allSummary.get(i);
+
+ if (unique[i]) {
+ summ.retained = true;
+ centroids.add(sorted_centers.getRow(i));
+ }
+
+ fitSummary.add(summ.toArray());
+ }
+
+
+ // calc redundant ct
+ redundant_ct = unique.length - centroids.size();
+
+
+ // also put the centroids into a matrix. We have to
+ // wait to perform this op, because we have to know
+ // the size of centroids first...
+ Array2DRowRealMatrix centers = new Array2DRowRealMatrix(centroids.size(), n);
+ for (int i = 0; i < centroids.size(); i++)
+ centers.setRow(i, centroids.get(i));
+
+
+ // Build yet another neighbors model...
+ NearestNeighbors nn = new NearestNeighbors(centers,
+ new NearestNeighborsParameters(1)
+ .setSeed(this.random_state)
+ .setMetric(this.dist_metric)
+ .setForceParallel(false), true).fit();
+
+
+ info((numClusters = centroids.size()) + " optimal kernel" + (numClusters != 1 ? "s" : "") + " identified");
+ info(redundant_ct + " nearly-identical kernel" + (redundant_ct != 1 ? "s" : "") + " removed");
+
+
+ // Get the nearest...
+ final LogTimer clustTimer = new LogTimer();
+ Neighborhood knrst = nn.getNeighbors(data.getDataRef());
+ labels = MatUtils.flatten(knrst.getIndices());
+
+
+ // order the labels..
+ /*
+ * Reduce labels to a sorted, gapless, list
+ * sklearn line: cluster_centers_indices = np.unique(labels)
+ */
+            ArrayList<Integer> centroidIndices = new ArrayList<>(numClusters);
+ for (Integer i : labels) // force autobox
+ if (!centroidIndices.contains(i)) // Not race condition because synchronized
+ centroidIndices.add(i);
+
+ /*
+ * final label assignment...
+ * sklearn line: labels = np.searchsorted(cluster_centers_indices, labels)
+ */
+ for (int i = 0; i < labels.length; i++)
+ labels[i] = centroidIndices.indexOf(labels[i]);
+
+
+ // Wrap up...
+ // Count missing
+ numNoisey = 0;
+ for (int lab : labels) if (lab == NOISE_CLASS) numNoisey++;
+ info(numNoisey + " record" + (numNoisey != 1 ? "s" : "") + " classified noise");
+
+
+ info("completed cluster labeling in " + clustTimer.toString());
+
+
+ sayBye(timer);
+ return this;
+ }
+
+ } // End train
+
+
+ @Override
+    public ArrayList<double[]> getCentroids() {
+        if (null != centroids) {
+            final ArrayList<double[]> cent = new ArrayList<>();
+ for (double[] d : centroids)
+ cent.add(VecUtils.copy(d));
+
+ return cent;
+ } else {
+ error(new ModelNotFitException("model has not yet been fit"));
+ return null; // can't happen
+ }
+ }
+
+ @Override
+ public int[] getLabels() {
+ return super.handleLabelCopy(labels);
+ }
+
+ static MeanShiftSeed singleSeed(double[] seed, RadiusNeighbors rn, double[][] X, int maxIter) {
+ final double bandwidth = rn.getRadius(), tolerance = 1e-3;
+ final int n = X[0].length; // we know X is uniform
+ int completed_iterations = 0;
+
+ double norm, diff;
+
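+        // Mean-shift iteration: replace the seed with the mean of all points
+        // within the bandwidth radius until the L2 norm of the shift falls
+        // below tolerance, or maxIter is exceeded.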
+ while (true) {
+
+ Neighborhood nbrs = rn.getNeighbors(new double[][]{seed}, bandwidth, false);
+ int[] i_nbrs = nbrs.getIndices()[0];
+
+ // Check if exit
+ if (i_nbrs.length == 0)
+ break;
+
+ // Save the old seed
+ final double[] oldSeed = seed;
+
+ // Get the points inside and simultaneously calc new seed
+ final double[] newSeed = new double[n];
+ norm = 0;
+ diff = 0;
+ for (int i = 0; i < i_nbrs.length; i++) {
+ final double[] record = X[i_nbrs[i]];
+
+ for (int j = 0; j < n; j++) {
+ newSeed[j] += record[j];
+
+ // Last iter hack, go ahead and compute means simultaneously
+ if (i == i_nbrs.length - 1) {
+ newSeed[j] /= (double) i_nbrs.length;
+ diff = newSeed[j] - oldSeed[j];
+ norm += diff * diff;
+ }
+ }
+ }
+
+ // Assign the new seed
+ seed = newSeed;
+ norm = FastMath.sqrt(norm);
+
+ // Check stopping criteria
+ if (completed_iterations++ == maxIter || norm < tolerance)
+ return new MeanShiftSeed(seed, i_nbrs.length, completed_iterations);
+ }
+
+ // Default... shouldn't get here though
+ return null;
+ }
+
+
+ @Override
+    protected final Object[] getModelFitSummaryHeaders() {
+ return new Object[]{
+ "Seed ID", "Iterations", "Iter. Time", "Wall", "Retained"
+ };
+ }
+
+ @Override
+ public int getNumberOfIdentifiedClusters() {
+ return numClusters;
+ }
+
+ @Override
+ public int getNumberOfNoisePoints() {
+ return numNoisey;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int[] predict(RealMatrix newData) {
+ return CentroidUtils.predict(this, newData);
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/MeanShiftParameters.java b/src/main/java/com/clust4j/algo/MeanShiftParameters.java
new file mode 100644
index 00000000..9b36efc5
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/MeanShiftParameters.java
@@ -0,0 +1,145 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+package com.clust4j.algo;
+
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+import com.clust4j.utils.MatUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+
+import java.util.Random;
+
+/**
+ * A builder class providing a convenient interface for
+ * setting custom parameters for MeanShift
+ *
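+ * <p>A minimal usage sketch (editorial; the parameter values are
+ * illustrative, and {@code data} is any {@code RealMatrix}):</p>
+ * <pre>{@code
+ * MeanShift model = new MeanShiftParameters() // auto-estimates bandwidth
+ *     .setMaxIter(300)
+ *     .setVerbose(true)
+ *     .fitNewModel(data);                     // builds and fits the model
+ * int[] labels = model.getLabels();
+ * }</pre>
+ *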
+ * @author Taylor G Smith
+ */
+final public class MeanShiftParameters
+ extends BaseClustererParameters
+    implements UnsupervisedClassifierParameters<MeanShift> {
+
+ private static final long serialVersionUID = -2276248235151049820L;
+ private boolean autoEstimateBW = false;
+ private double autoEstimateBWQuantile = 0.3;
+ private double bandwidth = MeanShift.DEF_BANDWIDTH;
+ private int maxIter = MeanShift.DEF_MAX_ITER;
+ private double minChange = MeanShift.DEF_TOL;
+ private double[][] seeds = null;
+
+
+    /** Default constructor; the bandwidth will be auto-estimated from the data. */
+    public MeanShiftParameters() {
+        this.autoEstimateBW = true;
+    }
+
+    /** Use a fixed kernel bandwidth rather than auto-estimating one. */
+    public MeanShiftParameters(final double bandwidth) {
+        this.bandwidth = bandwidth;
+    }
+
+
+ public boolean getAutoEstimate() {
+ return autoEstimateBW;
+ }
+
+ public double getAutoEstimationQuantile() {
+ return autoEstimateBWQuantile;
+ }
+
+ public double getBandwidth() {
+ return bandwidth;
+ }
+
+ public double[][] getSeeds() {
+ return seeds;
+ }
+
+ public int getMaxIter() {
+ return maxIter;
+ }
+
+ public double getConvergenceTolerance() {
+ return minChange;
+ }
+
+ @Override
+ public MeanShift fitNewModel(RealMatrix data) {
+ return new MeanShift(data, this.copy()).fit();
+ }
+
+ @Override
+ public MeanShiftParameters copy() {
+ return new MeanShiftParameters(bandwidth)
+ .setAutoBandwidthEstimation(autoEstimateBW)
+ .setAutoBandwidthEstimationQuantile(autoEstimateBWQuantile)
+ .setMaxIter(maxIter)
+ .setMinChange(minChange)
+ .setSeed(seed)
+ .setSeeds(seeds)
+ .setMetric(metric)
+ .setVerbose(verbose)
+ .setForceParallel(parallel);
+ }
+
+ public MeanShiftParameters setAutoBandwidthEstimation(boolean b) {
+ this.autoEstimateBW = b;
+ return this;
+ }
+
+ public MeanShiftParameters setAutoBandwidthEstimationQuantile(double d) {
+ this.autoEstimateBWQuantile = d;
+ return this;
+ }
+
+ public MeanShiftParameters setMaxIter(final int max) {
+ this.maxIter = max;
+ return this;
+ }
+
+ public MeanShiftParameters setMinChange(final double min) {
+ this.minChange = min;
+ return this;
+ }
+
+ @Override
+ public MeanShiftParameters setSeed(final Random seed) {
+ this.seed = seed;
+ return this;
+ }
+
+ public MeanShiftParameters setSeeds(final double[][] seeds) {
+ if (null != seeds)
+ this.seeds = MatUtils.copy(seeds);
+ return this;
+ }
+
+ @Override
+ public MeanShiftParameters setMetric(final GeometricallySeparable dist) {
+ this.metric = dist;
+ return this;
+ }
+
+ @Override
+ public MeanShiftParameters setVerbose(final boolean v) {
+ this.verbose = v;
+ return this;
+ }
+
+ @Override
+ public MeanShiftParameters setForceParallel(boolean b) {
+ this.parallel = b;
+ return this;
+ }
+}
diff --git a/src/main/java/com/clust4j/algo/MetricValidator.java b/src/main/java/com/clust4j/algo/MetricValidator.java
new file mode 100644
index 00000000..942e6838
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/MetricValidator.java
@@ -0,0 +1,22 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import com.clust4j.metrics.pairwise.GeometricallySeparable;
+
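+/**
+ * Implemented by models that restrict which distance metrics they support;
+ * {@code isValidMetric} reports whether the given metric can be used by the
+ * implementing algorithm. (Editorial doc comment.)
+ */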
+public interface MetricValidator {
+ public boolean isValidMetric(GeometricallySeparable geo);
+}
diff --git a/src/main/java/com/clust4j/algo/ModelSummary.java b/src/main/java/com/clust4j/algo/ModelSummary.java
new file mode 100644
index 00000000..d318fc92
--- /dev/null
+++ b/src/main/java/com/clust4j/algo/ModelSummary.java
@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright 2015, 2016 Taylor G Smith
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+package com.clust4j.algo;
+
+import java.util.ArrayList;
+
+/**
+ * The {@link com.clust4j.utils.TableFormatter} uses this class
+ * for pretty printing of various models' fit summaries.
+ *
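+ * <p>A minimal usage sketch (editorial; assumes a no-arg or varargs
+ * constructor and otherwise uses only the inherited
+ * {@code ArrayList<Object[]>} API, each row being one {@code Object[]}):</p>
+ * <pre>{@code
+ * ModelSummary summary = new ModelSummary();
+ * summary.add(new Object[]{"Seed ID", "Iterations", "Wall"});
+ * summary.add(new Object[]{0, 27, "31ms"});
+ * }</pre>
+ *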
+ * @author Taylor G Smith
+ */
+public class ModelSummary extends ArrayList<Object[]>