From 0cec4fc3a005d04820a242679f52fe4f8d21fe01 Mon Sep 17 00:00:00 2001 From: LTLA Date: Mon, 14 Oct 2024 14:34:50 -0700 Subject: [PATCH 1/3] Added utility to truncate the neighbor search results. This allows us to re-use the same search results in multiple NN-dependent functions that need different numbers of neighbors. --- js/findNearestNeighbors.js | 18 ++++++++++++++++++ src/NeighborIndex.cpp | 18 ++++++++++++++++++ tests/findNearestNeighbors.test.js | 29 +++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/js/findNearestNeighbors.js b/js/findNearestNeighbors.js index b1eac503..0ac60860 100644 --- a/js/findNearestNeighbors.js +++ b/js/findNearestNeighbors.js @@ -275,3 +275,21 @@ export function findNearestNeighbors(x, k, options = {}) { FindNearestNeighborsResults ); } + +/** + * Truncate existing neighbor search results to the `k` nearest neighbors for each cell. + * This is exactly or approximately equal to calling {@linkcode findNearestNeighbors} with the new `k`, + * depending on whether `approximate = false` or `approximate = true` was used to build the search index, respectively. + * + * @param {FindNEarestNeighborsResults} x Existing neighbor search results from {@linkcode findNearestNeighbors}. + * @param {number} k Number of neighbors to retain. + * If this is larger than the number of available neighbors, all neighbors are retained. + * + * @return {FindNearestNeighborsResults} Object containing the truncated search results. + */ +export function truncateNearestNeighbors(x, k) { + return gc.call( + module => module.truncate_nearest_neighbors(x.results, k), + FindNearestNeighborsResults + ); +} diff --git a/src/NeighborIndex.cpp b/src/NeighborIndex.cpp index 20cf1ab6..7078309f 100644 --- a/src/NeighborIndex.cpp +++ b/src/NeighborIndex.cpp @@ -5,6 +5,8 @@ #include "knncolle/knncolle.hpp" #include "knncolle_annoy/knncolle_annoy.hpp" +#include + std::unique_ptr, double> > create_builder(bool approximate) { std::unique_ptr, double> > builder; if (approximate) { @@ -30,9 +32,25 @@ NeighborResults find_nearest_neighbors(const NeighborIndex& index, int32_t k, in return output; } +NeighborResults truncate_nearest_neighbors(const NeighborResults& original, int32_t k) { + NeighborResults output; + size_t nobs = original.neighbors.size(); + output.neighbors.resize(nobs); + size_t desired = static_cast(k); + for (size_t i = 0; i ("NeighborIndex") diff --git a/tests/findNearestNeighbors.test.js b/tests/findNearestNeighbors.test.js index 3589628e..5a598391 100644 --- a/tests/findNearestNeighbors.test.js +++ b/tests/findNearestNeighbors.test.js @@ -60,6 +60,8 @@ test("neighbor search works with serialization", () => { // Dumping. var dump = res.serialize(); expect(dump.runs.length).toBe(ncells); + expect(dump.runs[0]).toBe(k); + expect(dump.runs[ncells-1]).toBe(k); expect(dump.indices.length).toBe(ncells * k); expect(dump.distances.length).toBe(ncells * k); @@ -92,3 +94,30 @@ test("neighbor search works with serialization", () => { buf_indices.free(); buf_distances.free(); }); + +test("neighbor search can be truncated", () => { + var ndim = 5; + var ncells = 100; + var buffer = scran.createFloat64WasmArray(ndim * ncells); + var arr = buffer.array(); + arr.forEach((x, i) => arr[i] = Math.random()); + + var index = scran.buildNeighborSearchIndex(buffer, { numberOfDims: ndim, numberOfCells: ncells }); + var k = 5; + var res = scran.findNearestNeighbors(index, k); + var dump = res.serialize(); + + var tres = scran.truncateNearestNeighbors(res, 2); + var tdump = tres.serialize(); + expect(tdump.runs.length).toBe(ncells); + expect(tdump.runs[0]).toBe(2); + expect(tdump.runs[ncells-1]).toBe(2); + expect(tdump.indices.length).toBe(ncells * 2); + expect(tdump.distances.length).toBe(ncells * 2); + + // Checking that the neighbors are the same. + expect(tdump.indices[0]).toEqual(dump.indices[0]); + expect(tdump.indices[2]).toEqual(dump.indices[5]); + expect(tdump.indices[5]).toEqual(dump.indices[11]); + expect(tdump.indices[51]).toEqual(dump.indices[126]); +}) From 94044391976372deac4959a7f43defd22701caca Mon Sep 17 00:00:00 2001 From: LTLA Date: Mon, 14 Oct 2024 14:56:32 -0700 Subject: [PATCH 2/3] Truncation also works during serialization. --- js/findNearestNeighbors.js | 24 ++++++++++++++++++------ src/NeighborIndex.h | 15 ++++++++++----- tests/findNearestNeighbors.test.js | 6 ++++++ 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/js/findNearestNeighbors.js b/js/findNearestNeighbors.js index 0ac60860..406422c8 100644 --- a/js/findNearestNeighbors.js +++ b/js/findNearestNeighbors.js @@ -122,11 +122,16 @@ export class FindNearestNeighborsResults { } /** + * @param {object} [options={}] - Optional parameters. + * @param {?number} [options.truncate=null] - Maximum number of neighbors to count for each cell. + * If `null` or greater than the number of available neighbors, all neighbors are counted. * @return {number} The total number of neighbors across all cells. * This is usually the product of the number of neighbors and the number of cells. */ - size() { - return this.#results.size(); + size(options = {}) { + const { truncate = null, ...others } = options; + utils.checkOtherOptions(others); + return this.#results.size(FindNearestNeighborsResults.#numberToTruncate(truncate)); } /** @@ -141,6 +146,10 @@ export class FindNearestNeighborsResults { return this.#results; } + static #numberToTruncate(truncate) { + return (truncate === null ? -1 : truncate); + } + /** * @param {object} [options={}] - Optional parameters. * @param {?Int32WasmArray} [options.runs=null] - A Wasm-allocated array of length equal to `numberOfCells()`, @@ -149,6 +158,8 @@ export class FindNearestNeighborsResults { * to be used to store the indices of the neighbors of each cell. * @param {?Float64WasmArray} [options.distances=null] - A Wasm-allocated array of length equal to `size()`, * to be used to store the distances to the neighbors of each cell. + * @param {?number} [options.truncate=null] - Number of nearest neighbors to serialize for each cell. + * If `null` or greater than the number of available neighbors, all neighbors are used. * * @return {object} * An object is returned with the `runs`, `indices` and `distances` keys, each with an appropriate TypedArray as the value. @@ -159,7 +170,7 @@ export class FindNearestNeighborsResults { * If only some of the arguments are non-`null`, an error is raised. */ serialize(options = {}) { - const { runs = null, indices = null, distances = null, ...others } = options; + const { runs = null, indices = null, distances = null, truncate = null, ...others } = options; utils.checkOtherOptions(others); var copy = (runs === null) + (indices === null) + (distances === null); @@ -167,6 +178,7 @@ export class FindNearestNeighborsResults { throw new Error("either all or none of 'runs', 'indices' and 'distances' can be 'null'"); } + let nkeep = FindNearestNeighborsResults.#numberToTruncate(truncate); var output; if (copy === 3) { @@ -176,10 +188,10 @@ export class FindNearestNeighborsResults { try { run_data = utils.createInt32WasmArray(this.numberOfCells()); - let s = this.size(); + let s = this.#results.size(nkeep); ind_data = utils.createInt32WasmArray(s); dist_data = utils.createFloat64WasmArray(s); - this.#results.serialize(run_data.offset, ind_data.offset, dist_data.offset); + this.#results.serialize(run_data.offset, ind_data.offset, dist_data.offset, nkeep); output = { "runs": run_data.slice(), @@ -193,7 +205,7 @@ export class FindNearestNeighborsResults { } } else { - this.#results.serialize(runs.offset, indices.offset, distances.offset); + this.#results.serialize(runs.offset, indices.offset, distances.offset, nkeep); output = { "runs": runs.array(), "indices": indices.array(), diff --git a/src/NeighborIndex.h b/src/NeighborIndex.h index e1480c09..476870bf 100644 --- a/src/NeighborIndex.h +++ b/src/NeighborIndex.h @@ -2,6 +2,7 @@ #define NEIGHBOR_INDEX_H #include +#include #include #include "knncolle/knncolle.hpp" @@ -44,10 +45,11 @@ struct NeighborResults { } public: - size_t size() const { + size_t size(int32_t truncate) const { size_t out = 0; + size_t long_truncate = truncate; for (const auto& current : neighbors) { - out += current.size(); + out += std::min(long_truncate, current.size()); } return out; } @@ -56,16 +58,19 @@ struct NeighborResults { return neighbors.size(); } - void serialize(uintptr_t runs, uintptr_t indices, uintptr_t distances) const { + void serialize(uintptr_t runs, uintptr_t indices, uintptr_t distances, int32_t truncate) const { auto rptr = reinterpret_cast(runs); auto iptr = reinterpret_cast(indices); auto dptr = reinterpret_cast(distances); + size_t long_truncate = truncate; for (const auto& current : neighbors) { - *rptr = current.size(); + size_t nkeep = std::min(long_truncate, current.size()); + *rptr = nkeep; ++rptr; - for (const auto& x : current) { + for (int32_t i = 0; i < nkeep; ++i) { + const auto& x = current[i]; *iptr = x.first; *dptr = x.second; ++iptr; diff --git a/tests/findNearestNeighbors.test.js b/tests/findNearestNeighbors.test.js index 5a598391..c2114907 100644 --- a/tests/findNearestNeighbors.test.js +++ b/tests/findNearestNeighbors.test.js @@ -120,4 +120,10 @@ test("neighbor search can be truncated", () => { expect(tdump.indices[2]).toEqual(dump.indices[5]); expect(tdump.indices[5]).toEqual(dump.indices[11]); expect(tdump.indices[51]).toEqual(dump.indices[126]); + + // Checking we get the same results with truncated serialization. + var tdump2 = res.serialize({ truncate: 2 }); + expect(tdump2.runs).toEqual(tdump.runs); + expect(tdump2.indices).toEqual(tdump.indices); + expect(tdump2.distances).toEqual(tdump.distances); }) From c5fa604dce605288ef3a0702e3f484b0681e15aa Mon Sep 17 00:00:00 2001 From: LTLA Date: Mon, 14 Oct 2024 14:57:39 -0700 Subject: [PATCH 3/3] Minor typo fix. --- js/findNearestNeighbors.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/findNearestNeighbors.js b/js/findNearestNeighbors.js index 406422c8..a296a625 100644 --- a/js/findNearestNeighbors.js +++ b/js/findNearestNeighbors.js @@ -293,7 +293,7 @@ export function findNearestNeighbors(x, k, options = {}) { * This is exactly or approximately equal to calling {@linkcode findNearestNeighbors} with the new `k`, * depending on whether `approximate = false` or `approximate = true` was used to build the search index, respectively. * - * @param {FindNEarestNeighborsResults} x Existing neighbor search results from {@linkcode findNearestNeighbors}. + * @param {FindNearestNeighborsResults} x Existing neighbor search results from {@linkcode findNearestNeighbors}. * @param {number} k Number of neighbors to retain. * If this is larger than the number of available neighbors, all neighbors are retained. *