diff --git a/README.md b/README.md index b4620cf..0550cf2 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,19 @@ umap.fit(data); const transformed = umap.transform(additionalData); ``` +#### Serialization + +```javascript +import { UMAP } from 'umap-js'; + +const umap = new UMAP(); +const embedding = umap.fit(data); +const serialized = umap.serialize(); +const umapCopy = UMAP.deserialize(serialized); +``` + +#### Asynchronous fitting + #### Parameters The UMAP constructor can accept a number of hyperparameters via a `UMAPParameters` object, with the most common described below. See [umap.ts](./src/umap.ts) for more details. diff --git a/src/matrix.ts b/src/matrix.ts index 815760c..4da16b8 100644 --- a/src/matrix.ts +++ b/src/matrix.ts @@ -21,6 +21,12 @@ import * as utils from './utils'; type Entry = { value: number; row: number; col: number }; +export type SerializedSparseMatrix = { + entries: [string, Entry][]; + nRows: number; + nCols: number; +}; + /** * Internal 2-dimensional sparse matrix class */ @@ -142,6 +148,29 @@ export class SparseMatrix { }); return output; } + + setEntries(entries: [string, Entry][]) { + this.entries = new Map(entries); + } + + serialize(): SerializedSparseMatrix { + return { + nRows: this.nRows, + nCols: this.nCols, + entries: Array.from(this.entries.entries()), + }; + } + + static deserialize(serMatrix: SerializedSparseMatrix): SparseMatrix { + const sparseMatrix = new SparseMatrix( + [], + [], + [], + [serMatrix.nRows, serMatrix.nCols] + ); + sparseMatrix.setEntries(serMatrix.entries); + return sparseMatrix; + } } /** diff --git a/src/tree.ts b/src/tree.ts index 99da3b0..f58b0d7 100644 --- a/src/tree.ts +++ b/src/tree.ts @@ -72,6 +72,13 @@ interface RandomProjectionTreeNode { offset?: number; } +export type SerializedFlatTree = { + hyperplanes: number[][]; + offsets: number[]; + children: number[][]; + indices: number[][]; +}; + export class FlatTree { constructor( public hyperplanes: number[][], @@ -79,6 +86,24 @@ export class FlatTree { public children: number[][], public indices: number[][] ) {} + + serialize(): SerializedFlatTree { + return { + hyperplanes: this.hyperplanes, + offsets: this.offsets, + children: this.children, + indices: this.indices, + }; + } + + static deserialize(serTree: SerializedFlatTree): FlatTree { + return new FlatTree( + serTree.hyperplanes, + serTree.offsets, + serTree.children, + serTree.indices + ); + } } /** diff --git a/src/umap.ts b/src/umap.ts index 14edea5..2e76301 100644 --- a/src/umap.ts +++ b/src/umap.ts @@ -75,6 +75,38 @@ export const enum TargetMetric { l2 = 'l2', } +export type SerializedUMAP = { + learningRate: number; + localConnectivity: number; + minDist: number; + nComponents: number; + nEpochs: number; + nNeighbors: number; + negativeSampleRate: number; + repulsionStrength: number; + setOpMixRatio: number; + spread: number; + transformQueueSize: number; + + targetMetric: TargetMetric; + targetWeight: number; + targetNNeighbors: number; + + knnIndices?: number[][]; + knnDistances?: number[][]; + + graph: matrix.SerializedSparseMatrix; + X: number[][]; + isInitialized: boolean; + rpForest: tree.SerializedFlatTree[]; + searchGraph: matrix.SerializedSparseMatrix; + + Y?: number[]; + embedding: number[][]; + + optimizationState: SerializedOptimizationState; +}; + const SMOOTH_K_TOLERANCE = 1e-5; const MIN_K_DIST_SCALE = 1e-3; @@ -323,7 +355,11 @@ export class UMAP { */ initializeFit(X: Vectors): number { if (X.length <= this.nNeighbors) { - throw new Error(`Not enough data points (${X.length}) to create nNeighbors: ${this.nNeighbors}. Add more data points or adjust the configuration.`); + throw new Error( + `Not enough data points (${X.length}) to create nNeighbors: ${ + this.nNeighbors + }. Add more data points or adjust the configuration.` + ); } // We don't need to reinitialize if we've already initialized for this data. @@ -1083,6 +1119,101 @@ export class UMAP { return 200; } } + + setParameters(params: SerializedUMAP) { + this.learningRate = params.learningRate; + this.localConnectivity = params.localConnectivity; + this.minDist = params.minDist; + this.nComponents = params.nComponents; + this.nEpochs = params.nEpochs; + this.nNeighbors = params.nNeighbors; + this.negativeSampleRate = params.negativeSampleRate; + this.repulsionStrength = params.repulsionStrength; + this.setOpMixRatio = params.setOpMixRatio; + this.spread = params.spread; + this.transformQueueSize = params.transformQueueSize; + this.targetMetric = params.targetMetric; + this.targetWeight = params.targetWeight; + this.targetNNeighbors = params.targetNNeighbors; + this.knnIndices = params.knnIndices; + this.knnDistances = params.knnDistances; + this.graph = matrix.SparseMatrix.deserialize(params.graph); + this.X = params.X; + this.isInitialized = params.isInitialized; + this.rpForest = params.rpForest.map(tree.FlatTree.deserialize); + this.searchGraph = matrix.SparseMatrix.deserialize(params.searchGraph); + this.Y = params.Y; + this.embedding = params.embedding; + this.optimizationState = OptimizationState.deserialize( + params.optimizationState + ); + } + + serialize(): SerializedUMAP { + const { + learningRate, + localConnectivity, + minDist, + nComponents, + nEpochs, + nNeighbors, + negativeSampleRate, + repulsionStrength, + setOpMixRatio, + spread, + transformQueueSize, + targetMetric, + targetWeight, + targetNNeighbors, + knnIndices, + knnDistances, + graph, + X, + isInitialized, + rpForest, + searchGraph, + Y, + embedding, + optimizationState, + } = this; + + return { + learningRate, + localConnectivity, + minDist, + nComponents, + nEpochs, + nNeighbors, + negativeSampleRate, + repulsionStrength, + setOpMixRatio, + spread, + transformQueueSize, + targetMetric, + targetWeight, + targetNNeighbors, + knnIndices, + knnDistances, + graph: graph.serialize(), + X, + isInitialized, + rpForest: rpForest.map(t => t.serialize()), + searchGraph: searchGraph.serialize(), + Y, + embedding, + optimizationState: optimizationState.serialize(), + }; + } + + static deserialize( + serUmap: SerializedUMAP, + params: Pick = {} + ): UMAP { + const umap = new UMAP(params); + umap.setParameters(serUmap); + umap.makeSearchFns(); + return umap; + } } export function euclidean(x: Vector, y: Vector) { @@ -1113,6 +1244,29 @@ export function cosine(x: Vector, y: Vector) { } } +type SerializedOptimizationState = { + currentEpoch: number; + + // Data tracked during optimization steps. + headEmbedding: number[][]; + tailEmbedding: number[][]; + head: number[]; + tail: number[]; + epochsPerSample: number[]; + epochOfNextSample: number[]; + epochOfNextNegativeSample: number[]; + epochsPerNegativeSample: number[]; + moveOther: boolean; + initialAlpha: number; + alpha: number; + gamma: number; + a: number; + b: number; + dim: number; + nEpochs: number; + nVertices: number; +}; + /** * An interface representing the optimization state tracked between steps of * the SGD optimization @@ -1138,6 +1292,77 @@ class OptimizationState { dim = 2; nEpochs = 500; nVertices = 0; + + serialize(): SerializedOptimizationState { + const { + currentEpoch, + headEmbedding, + tailEmbedding, + head, + tail, + epochsPerSample, + epochOfNextSample, + epochOfNextNegativeSample, + epochsPerNegativeSample, + moveOther, + initialAlpha, + alpha, + gamma, + a, + b, + dim, + nEpochs, + nVertices, + } = this; + + return { + currentEpoch, + headEmbedding, + tailEmbedding, + head, + tail, + epochsPerSample, + epochOfNextSample, + epochOfNextNegativeSample, + epochsPerNegativeSample, + moveOther, + initialAlpha, + alpha, + gamma, + a, + b, + dim, + nEpochs, + nVertices, + }; + } + + static deserialize(serState: SerializedOptimizationState): OptimizationState { + const optimizationState = new OptimizationState(); + + optimizationState.currentEpoch = serState.currentEpoch; + optimizationState.headEmbedding = serState.headEmbedding; + optimizationState.tailEmbedding = serState.tailEmbedding; + optimizationState.head = serState.head; + optimizationState.tail = serState.tail; + optimizationState.epochsPerSample = serState.epochsPerSample; + optimizationState.epochOfNextSample = serState.epochOfNextSample; + optimizationState.epochOfNextNegativeSample = + serState.epochOfNextNegativeSample; + optimizationState.epochsPerNegativeSample = + serState.epochsPerNegativeSample; + optimizationState.moveOther = serState.moveOther; + optimizationState.initialAlpha = serState.initialAlpha; + optimizationState.alpha = serState.alpha; + optimizationState.gamma = serState.gamma; + optimizationState.a = serState.a; + optimizationState.b = serState.b; + optimizationState.dim = serState.dim; + optimizationState.nEpochs = serState.nEpochs; + optimizationState.nVertices = serState.nVertices; + + return optimizationState; + } } /** diff --git a/test/serialize.test.ts b/test/serialize.test.ts new file mode 100644 index 0000000..06debbe --- /dev/null +++ b/test/serialize.test.ts @@ -0,0 +1,100 @@ +/** + * @license + * + * Copyright 2019 Google LLC. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================== + */ + +import { + UMAP, + findABParams, + euclidean, + RandomFn, + TargetMetric, + Vector, +} from '../src/umap'; +import * as utils from '../src/utils'; +import { + additionalData, + additionalLabels, + testData, + testLabels, + testResults2D, + testResults3D, +} from './test_data'; +import Prando from 'prando'; + +describe('UMAP', () => { + let random: RandomFn; + + // Expected "clustering" ratios, representing inter-cluster distance vs mean + // distance to other points. + const UNSUPERVISED_CLUSTER_RATIO = 0.15; + const SUPERVISED_CLUSTER_RATIO = 0.04; + + beforeEach(() => { + const prng = new Prando(42); + random = () => prng.next(); + }); + + test('transforms an additional point after fitting', () => { + const umap = new UMAP({ random, nComponents: 2 }); + const embedding = umap.fit(testData); + + const serUmap = JSON.stringify(umap.serialize()); + const umapCopy = UMAP.deserialize(JSON.parse(serUmap)); + + const additional = additionalData[0]; + const cpTransformed = umapCopy.transform([additional]); + + const cpNearestIndex = getNearestNeighborIndex(embedding, cpTransformed[0]); + const cpNearestLabel = testLabels[cpNearestIndex]; + expect(cpNearestLabel).toEqual(additionalLabels[0]); + }); + + test('transforms additional points after fitting', () => { + const umap = new UMAP({ random, nComponents: 2 }); + const embedding = umap.fit(testData); + + const serUmap = JSON.stringify(umap.serialize()); + const umapCopy = UMAP.deserialize(JSON.parse(serUmap)); + + const transformed = umapCopy.transform(additionalData); + + for (let i = 0; i < transformed.length; i++) { + const nearestIndex = getNearestNeighborIndex(embedding, transformed[i]); + const nearestLabel = testLabels[nearestIndex]; + expect(nearestLabel).toEqual(additionalLabels[i]); + } + }); +}); + +function getNearestNeighborIndex( + items: number[][], + otherPoint: number[], + distanceFn = euclidean +) { + const nearest = items.reduce( + (result, point, pointIndex) => { + const pointDistance = distanceFn(point, otherPoint); + if (pointDistance < result.distance) { + return { index: pointIndex, distance: pointDistance }; + } + return result; + }, + { index: 0, distance: Infinity } + ); + return nearest.index; +} diff --git a/test/umap.test.ts b/test/umap.test.ts index 812455b..0be24a8 100644 --- a/test/umap.test.ts +++ b/test/umap.test.ts @@ -206,8 +206,10 @@ describe('UMAP', () => { test('initializeFit throws helpful error if not enough data', () => { const umap = new UMAP({ random }); const smallData = testData.slice(0, 15); - expect(() => umap.initializeFit(smallData)).toThrow(/Not enough data points/); - }) + expect(() => umap.initializeFit(smallData)).toThrow( + /Not enough data points/ + ); + }); }); function computeMeanDistances(vectors: number[][]) { diff --git a/yarn.lock b/yarn.lock index 85769ac..f4505ca 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4465,9 +4465,10 @@ typedarray@^0.0.6: version "0.0.6" resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777" -typescript@^3.2.4: - version "3.2.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.2.4.tgz#c585cb952912263d915b462726ce244ba510ef3d" +typescript@^3.9.10: + version "3.9.10" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.9.10.tgz#70f3910ac7a51ed6bef79da7800690b19bf778b8" + integrity sha512-w6fIxVE/H1PkLKcCPsFqKE7Kv7QUwhU8qQY2MueZXWx5cPZdwFupLgKK3vntcK98BtNHZtAF4LA/yl2a7k8R6Q== uglify-js@^3.1.4: version "3.6.0"