diff --git a/umap/umap_.py b/umap/umap_.py index 52c3a280..695f0bf8 100644 --- a/umap/umap_.py +++ b/umap/umap_.py @@ -1044,18 +1044,24 @@ def simplicial_set_embedding( graph.sum_duplicates() n_vertices = graph.shape[1] - if n_epochs <= 0: - # For smaller datasets we can use more epochs - if graph.shape[0] <= 10000: - n_epochs = 500 - else: - n_epochs = 200 + # For smaller datasets we can use more epochs + if graph.shape[0] <= 10000: + default_epochs = 500 + else: + default_epochs = 200 - # Use more epochs for densMAP - if densmap: - n_epochs += 200 + # Use more epochs for densMAP + if densmap: + default_epochs += 200 + + if n_epochs is None: + n_epochs = default_epochs + + if n_epochs > 10: + graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0 + else: + graph.data[graph.data < (graph.data.max() / float(default_epochs))] = 0.0 - graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0 graph.eliminate_zeros() if isinstance(init, str) and init == "random": @@ -1702,9 +1708,9 @@ def _validate_parameters(self): if self.n_components < 1: raise ValueError("n_components must be greater than 0") if self.n_epochs is not None and ( - self.n_epochs <= 10 or not isinstance(self.n_epochs, int) + self.n_epochs < 0 or not isinstance(self.n_epochs, int) ): - raise ValueError("n_epochs must be a positive integer of at least 10") + raise ValueError("n_epochs must be a nonnegative integer") if self.metric_kwds is None: self._metric_kwds = {} else: @@ -2567,11 +2573,6 @@ def fit(self, X, y=None): else: self._supervised = False - if self.n_epochs is None: - n_epochs = 0 - else: - n_epochs = self.n_epochs - if self.densmap or self.output_dens: self._densmap_kwds["graph_dists"] = self.graph_dists_ @@ -2581,7 +2582,7 @@ def fit(self, X, y=None): if self.transform_mode == "embedding": self.embedding_, aux_data = self._fit_embed_data( self._raw_data[index], - n_epochs, + self.n_epochs, init, random_state, # JH why raw data? )