diff --git a/umap/plot.py b/umap/plot.py index f7a8af4f..3979bac5 100644 --- a/umap/plot.py +++ b/umap/plot.py @@ -166,14 +166,24 @@ def _nhood_search(umap_object, nhood_size): else: rng_state = np.empty(3, dtype=np.int64) + if len(umap_object._metric_kwds) >= 1: + _dist = umap_object._input_distance_func + _args = tuple(umap_object._metric_kwds.values()) + + @numba.njit() + def _metric(x, y): + _dist(x, y, *_args) + + else: + _metric = umap_object._input_distance_func + init = initialise_search( umap_object._rp_forest, umap_object._raw_data, umap_object._raw_data, int(nhood_size * umap_object.transform_queue_size), rng_state, - umap_object._distance_func, - umap_object._dist_args, + _metric, ) result = initialized_nnd_search( @@ -182,8 +192,7 @@ def _nhood_search(umap_object, nhood_size): umap_object._search_graph.indices, init, umap_object._raw_data, - umap_object._distance_func, - umap_object._dist_args, + _metric, ) indices, dists = deheap_sort(result)