Skip to content

Commit

Permalink
Fix for #398
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Apr 10, 2020
1 parent 791085d commit 5809c18
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions umap/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,24 @@ def _nhood_search(umap_object, nhood_size):
else:
rng_state = np.empty(3, dtype=np.int64)

if len(umap_object._metric_kwds) >= 1:
_dist = umap_object._input_distance_func
_args = tuple(umap_object._metric_kwds.values())

@numba.njit()
def _metric(x, y):
_dist(x, y, *_args)

else:
_metric = umap_object._input_distance_func

init = initialise_search(
umap_object._rp_forest,
umap_object._raw_data,
umap_object._raw_data,
int(nhood_size * umap_object.transform_queue_size),
rng_state,
umap_object._distance_func,
umap_object._dist_args,
_metric,
)

result = initialized_nnd_search(
Expand All @@ -182,8 +192,7 @@ def _nhood_search(umap_object, nhood_size):
umap_object._search_graph.indices,
init,
umap_object._raw_data,
umap_object._distance_func,
umap_object._dist_args,
_metric,
)

indices, dists = deheap_sort(result)
Expand Down

0 comments on commit 5809c18

Please sign in to comment.