Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added utility to truncate the neighbor search results. #93

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions js/findNearestNeighbors.js
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,16 @@ export class FindNearestNeighborsResults {
}

/**
* @param {object} [options={}] - Optional parameters.
* @param {?number} [options.truncate=null] - Maximum number of neighbors to count for each cell.
* If `null` or greater than the number of available neighbors, all neighbors are counted.
* @return {number} The total number of neighbors across all cells.
* This is usually the product of the number of neighbors and the number of cells.
*/
size() {
return this.#results.size();
size(options = {}) {
const { truncate = null, ...others } = options;
utils.checkOtherOptions(others);
return this.#results.size(FindNearestNeighborsResults.#numberToTruncate(truncate));
}

/**
Expand All @@ -141,6 +146,10 @@ export class FindNearestNeighborsResults {
return this.#results;
}

static #numberToTruncate(truncate) {
return (truncate === null ? -1 : truncate);
}

/**
* @param {object} [options={}] - Optional parameters.
* @param {?Int32WasmArray} [options.runs=null] - A Wasm-allocated array of length equal to `numberOfCells()`,
Expand All @@ -149,6 +158,8 @@ export class FindNearestNeighborsResults {
* to be used to store the indices of the neighbors of each cell.
* @param {?Float64WasmArray} [options.distances=null] - A Wasm-allocated array of length equal to `size()`,
* to be used to store the distances to the neighbors of each cell.
* @param {?number} [options.truncate=null] - Number of nearest neighbors to serialize for each cell.
* If `null` or greater than the number of available neighbors, all neighbors are used.
*
* @return {object}
* An object is returned with the `runs`, `indices` and `distances` keys, each with an appropriate TypedArray as the value.
Expand All @@ -159,14 +170,15 @@ export class FindNearestNeighborsResults {
* If only some of the arguments are non-`null`, an error is raised.
*/
serialize(options = {}) {
const { runs = null, indices = null, distances = null, ...others } = options;
const { runs = null, indices = null, distances = null, truncate = null, ...others } = options;
utils.checkOtherOptions(others);

var copy = (runs === null) + (indices === null) + (distances === null);
if (copy != 3 && copy != 0) {
throw new Error("either all or none of 'runs', 'indices' and 'distances' can be 'null'");
}

let nkeep = FindNearestNeighborsResults.#numberToTruncate(truncate);
var output;

if (copy === 3) {
Expand All @@ -176,10 +188,10 @@ export class FindNearestNeighborsResults {

try {
run_data = utils.createInt32WasmArray(this.numberOfCells());
let s = this.size();
let s = this.#results.size(nkeep);
ind_data = utils.createInt32WasmArray(s);
dist_data = utils.createFloat64WasmArray(s);
this.#results.serialize(run_data.offset, ind_data.offset, dist_data.offset);
this.#results.serialize(run_data.offset, ind_data.offset, dist_data.offset, nkeep);

output = {
"runs": run_data.slice(),
Expand All @@ -193,7 +205,7 @@ export class FindNearestNeighborsResults {
}

} else {
this.#results.serialize(runs.offset, indices.offset, distances.offset);
this.#results.serialize(runs.offset, indices.offset, distances.offset, nkeep);
output = {
"runs": runs.array(),
"indices": indices.array(),
Expand Down Expand Up @@ -275,3 +287,21 @@ export function findNearestNeighbors(x, k, options = {}) {
FindNearestNeighborsResults
);
}

/**
* Truncate existing neighbor search results to the `k` nearest neighbors for each cell.
* This is exactly or approximately equal to calling {@linkcode findNearestNeighbors} with the new `k`,
* depending on whether `approximate = false` or `approximate = true` was used to build the search index, respectively.
*
* @param {FindNearestNeighborsResults} x Existing neighbor search results from {@linkcode findNearestNeighbors}.
* @param {number} k Number of neighbors to retain.
* If this is larger than the number of available neighbors, all neighbors are retained.
*
* @return {FindNearestNeighborsResults} Object containing the truncated search results.
*/
export function truncateNearestNeighbors(x, k) {
return gc.call(
module => module.truncate_nearest_neighbors(x.results, k),
FindNearestNeighborsResults
);
}
18 changes: 18 additions & 0 deletions src/NeighborIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "knncolle/knncolle.hpp"
#include "knncolle_annoy/knncolle_annoy.hpp"

#include <algorithm>

std::unique_ptr<knncolle::Builder<knncolle::SimpleMatrix<int32_t, int32_t, double>, double> > create_builder(bool approximate) {
std::unique_ptr<knncolle::Builder<knncolle::SimpleMatrix<int32_t, int32_t, double>, double> > builder;
if (approximate) {
Expand All @@ -30,9 +32,25 @@ NeighborResults find_nearest_neighbors(const NeighborIndex& index, int32_t k, in
return output;
}

NeighborResults truncate_nearest_neighbors(const NeighborResults& original, int32_t k) {
NeighborResults output;
size_t nobs = original.neighbors.size();
output.neighbors.resize(nobs);
size_t desired = static_cast<size_t>(k);
for (size_t i = 0; i <nobs; ++i) {
const auto& current = original.neighbors[i];
auto& curout = output.neighbors[i];
size_t size = std::min(current.size(), desired);
curout.insert(curout.end(), current.begin(), current.begin() + size);
}
return output;
}

EMSCRIPTEN_BINDINGS(build_neighbor_index) {
emscripten::function("find_nearest_neighbors", &find_nearest_neighbors, emscripten::return_value_policy::take_ownership());

emscripten::function("truncate_nearest_neighbors", &truncate_nearest_neighbors, emscripten::return_value_policy::take_ownership());

emscripten::function("build_neighbor_index", &build_neighbor_index, emscripten::return_value_policy::take_ownership());

emscripten::class_<NeighborIndex>("NeighborIndex")
Expand Down
15 changes: 10 additions & 5 deletions src/NeighborIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define NEIGHBOR_INDEX_H

#include <memory>
#include <algorithm>
#include <vector>

#include "knncolle/knncolle.hpp"
Expand Down Expand Up @@ -44,10 +45,11 @@ struct NeighborResults {
}

public:
size_t size() const {
size_t size(int32_t truncate) const {
size_t out = 0;
size_t long_truncate = truncate;
for (const auto& current : neighbors) {
out += current.size();
out += std::min(long_truncate, current.size());
}
return out;
}
Expand All @@ -56,16 +58,19 @@ struct NeighborResults {
return neighbors.size();
}

void serialize(uintptr_t runs, uintptr_t indices, uintptr_t distances) const {
void serialize(uintptr_t runs, uintptr_t indices, uintptr_t distances, int32_t truncate) const {
auto rptr = reinterpret_cast<int32_t*>(runs);
auto iptr = reinterpret_cast<int32_t*>(indices);
auto dptr = reinterpret_cast<double*>(distances);

size_t long_truncate = truncate;
for (const auto& current : neighbors) {
*rptr = current.size();
size_t nkeep = std::min(long_truncate, current.size());
*rptr = nkeep;
++rptr;

for (const auto& x : current) {
for (int32_t i = 0; i < nkeep; ++i) {
const auto& x = current[i];
*iptr = x.first;
*dptr = x.second;
++iptr;
Expand Down
35 changes: 35 additions & 0 deletions tests/findNearestNeighbors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ test("neighbor search works with serialization", () => {
// Dumping.
var dump = res.serialize();
expect(dump.runs.length).toBe(ncells);
expect(dump.runs[0]).toBe(k);
expect(dump.runs[ncells-1]).toBe(k);
expect(dump.indices.length).toBe(ncells * k);
expect(dump.distances.length).toBe(ncells * k);

Expand Down Expand Up @@ -92,3 +94,36 @@ test("neighbor search works with serialization", () => {
buf_indices.free();
buf_distances.free();
});

test("neighbor search can be truncated", () => {
var ndim = 5;
var ncells = 100;
var buffer = scran.createFloat64WasmArray(ndim * ncells);
var arr = buffer.array();
arr.forEach((x, i) => arr[i] = Math.random());

var index = scran.buildNeighborSearchIndex(buffer, { numberOfDims: ndim, numberOfCells: ncells });
var k = 5;
var res = scran.findNearestNeighbors(index, k);
var dump = res.serialize();

var tres = scran.truncateNearestNeighbors(res, 2);
var tdump = tres.serialize();
expect(tdump.runs.length).toBe(ncells);
expect(tdump.runs[0]).toBe(2);
expect(tdump.runs[ncells-1]).toBe(2);
expect(tdump.indices.length).toBe(ncells * 2);
expect(tdump.distances.length).toBe(ncells * 2);

// Checking that the neighbors are the same.
expect(tdump.indices[0]).toEqual(dump.indices[0]);
expect(tdump.indices[2]).toEqual(dump.indices[5]);
expect(tdump.indices[5]).toEqual(dump.indices[11]);
expect(tdump.indices[51]).toEqual(dump.indices[126]);

// Checking we get the same results with truncated serialization.
var tdump2 = res.serialize({ truncate: 2 });
expect(tdump2.runs).toEqual(tdump.runs);
expect(tdump2.indices).toEqual(tdump.indices);
expect(tdump2.distances).toEqual(tdump.distances);
})