Skip to content

Commit

Permalink
Enable n_epochs=0 to get the initial embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
jondo committed Mar 30, 2021
1 parent f86c922 commit e442bcd
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,18 +1044,24 @@ def simplicial_set_embedding(
graph.sum_duplicates()
n_vertices = graph.shape[1]

if n_epochs <= 0:
# For smaller datasets we can use more epochs
if graph.shape[0] <= 10000:
n_epochs = 500
else:
n_epochs = 200
# For smaller datasets we can use more epochs
if graph.shape[0] <= 10000:
default_epochs = 500
else:
default_epochs = 200

# Use more epochs for densMAP
if densmap:
n_epochs += 200
# Use more epochs for densMAP
if densmap:
default_epochs += 200

if n_epochs is None:
n_epochs = default_epochs

if n_epochs > 10:
graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0
else:
graph.data[graph.data < (graph.data.max() / float(default_epochs))] = 0.0

graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0
graph.eliminate_zeros()

if isinstance(init, str) and init == "random":
Expand Down Expand Up @@ -1702,9 +1708,9 @@ def _validate_parameters(self):
if self.n_components < 1:
raise ValueError("n_components must be greater than 0")
if self.n_epochs is not None and (
self.n_epochs <= 10 or not isinstance(self.n_epochs, int)
self.n_epochs < 0 or not isinstance(self.n_epochs, int)
):
raise ValueError("n_epochs must be a positive integer of at least 10")
raise ValueError("n_epochs must be a nonnegative integer")
if self.metric_kwds is None:
self._metric_kwds = {}
else:
Expand Down Expand Up @@ -2567,11 +2573,6 @@ def fit(self, X, y=None):
else:
self._supervised = False

if self.n_epochs is None:
n_epochs = 0
else:
n_epochs = self.n_epochs

if self.densmap or self.output_dens:
self._densmap_kwds["graph_dists"] = self.graph_dists_

Expand All @@ -2581,7 +2582,7 @@ def fit(self, X, y=None):
if self.transform_mode == "embedding":
self.embedding_, aux_data = self._fit_embed_data(
self._raw_data[index],
n_epochs,
self.n_epochs,
init,
random_state, # JH why raw data?
)
Expand Down

0 comments on commit e442bcd

Please sign in to comment.