Implement masking to control how embedded points are updated
matthieuheitz committed Jun 16, 2021
1 parent 42b3f1f commit c7fe35e
Showing 3 changed files with 395 additions and 24 deletions.
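
The headline addition is optimize_layout_euclidean_masked in umap/layouts.py, a variant of optimize_layout_euclidean that takes a per-point mask scaling each point's gradient updates (0 pins a point, 1 updates it normally). Below is a minimal sketch of a call, assuming the graph inputs (head, tail, epochs_per_sample) came from UMAP's usual fuzzy simplicial set construction; the toy arrays and the a, b values are illustrative only:

    import numpy as np
    from umap.layouts import optimize_layout_euclidean_masked

    rng = np.random.RandomState(42)

    # Toy inputs: 4 points in 2-D and 3 edges of a fuzzy graph.
    embedding = rng.uniform(-10, 10, (4, 2)).astype(np.float32)
    head = np.array([0, 1, 2], dtype=np.int32)
    tail = np.array([1, 2, 3], dtype=np.int32)
    epochs_per_sample = np.array([1.0, 1.5, 1.0])

    # Pin point 0 in place; let the other points move normally.
    mask = np.array([0.0, 1.0, 1.0, 1.0], dtype=np.float32)

    # Internal RNG state, seeded the way UMAP does elsewhere.
    rng_state = rng.randint(
        np.iinfo(np.int32).min, np.iinfo(np.int32).max, 3
    ).astype(np.int64)

    embedding = optimize_layout_euclidean_masked(
        embedding, embedding, head, tail, mask,
        n_epochs=50, n_vertices=4,
        epochs_per_sample=epochs_per_sample,
        a=1.58, b=0.9,  # roughly what UMAP fits for min_dist=0.1
        rng_state=rng_state,
    )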
307 changes: 307 additions & 0 deletions umap/layouts.py
@@ -181,6 +181,135 @@ def _optimize_layout_euclidean_single_epoch(
)


def _optimize_layout_euclidean_masked_single_epoch(
head_embedding,
tail_embedding,
head,
tail,
mask,
n_vertices,
epochs_per_sample,
a,
b,
rng_state,
gamma,
dim,
move_other,
alpha,
epochs_per_negative_sample,
epoch_of_next_negative_sample,
epoch_of_next_sample,
n,
densmap_flag,
dens_phi_sum,
dens_re_sum,
dens_re_cov,
dens_re_std,
dens_re_mean,
dens_lambda,
dens_R,
dens_mu,
dens_mu_tot,
):
for i in numba.prange(epochs_per_sample.shape[0]):
if epoch_of_next_sample[i] <= n:
j = head[i]
k = tail[i]

current = head_embedding[j]
other = tail_embedding[k]

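            # Each endpoint's mask scales its own update: 0 pins the point,
            # 1 moves it normally.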
current_mask = mask[j]
other_mask = mask[k]

dist_squared = rdist(current, other)

if densmap_flag:
phi = 1.0 / (1.0 + a * pow(dist_squared, b))
dphi_term = (
a * b * pow(dist_squared, b - 1) / (1.0 + a * pow(dist_squared, b))
)

q_jk = phi / dens_phi_sum[k]
q_kj = phi / dens_phi_sum[j]

drk = q_jk * (
(1.0 - b * (1 - phi)) / np.exp(dens_re_sum[k]) + dphi_term
)
drj = q_kj * (
(1.0 - b * (1 - phi)) / np.exp(dens_re_sum[j]) + dphi_term
)

re_std_sq = dens_re_std * dens_re_std
weight_k = (
dens_R[k]
- dens_re_cov * (dens_re_sum[k] - dens_re_mean) / re_std_sq
)
weight_j = (
dens_R[j]
- dens_re_cov * (dens_re_sum[j] - dens_re_mean) / re_std_sq
)

grad_cor_coeff = (
dens_lambda
* dens_mu_tot
* (weight_k * drk + weight_j * drj)
/ (dens_mu[i] * dens_re_std)
/ n_vertices
)

if dist_squared > 0.0:
grad_coeff = -2.0 * a * b * pow(dist_squared, b - 1.0)
grad_coeff /= a * pow(dist_squared, b) + 1.0
else:
grad_coeff = 0.0

for d in range(dim):
grad_d = clip(grad_coeff * (current[d] - other[d]))

if densmap_flag:
grad_d += clip(2 * grad_cor_coeff * (current[d] - other[d]))

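                # Scale the attractive step by each point's mask; a fully masked
                # point (mask 0) stays put but still pulls on its neighbors.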
current[d] += current_mask * grad_d * alpha
if move_other:
                    other[d] += -other_mask * grad_d * alpha

epoch_of_next_sample[i] += epochs_per_sample[i]

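            # Negative samples owed for this edge, given the epochs elapsed
            # since its last negative-sampling update.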
n_neg_samples = int(
(n - epoch_of_next_negative_sample[i]) / epochs_per_negative_sample[i]
)

for p in range(n_neg_samples):
k = tau_rand_int(rng_state) % n_vertices

other = tail_embedding[k]

dist_squared = rdist(current, other)

if dist_squared > 0.0:
grad_coeff = 2.0 * gamma * b
grad_coeff /= (0.001 + dist_squared) * (
a * pow(dist_squared, b) + 1
)
elif j == k:
continue
else:
grad_coeff = 0.0

for d in range(dim):
if grad_coeff > 0.0:
grad_d = clip(grad_coeff * (current[d] - other[d]))
else:
grad_d = 4.0
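                    # Repulsion moves only the current point, again scaled by its mask.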
current[d] += current_mask * grad_d * alpha

epoch_of_next_negative_sample[i] += (
n_neg_samples * epochs_per_negative_sample[i]
)



def _optimize_layout_euclidean_densmap_epoch_init(
head_embedding, tail_embedding, head, tail, a, b, re_sum, phi_sum,
):
@@ -379,6 +508,184 @@ def optimize_layout_euclidean(
return head_embedding


def optimize_layout_euclidean_masked(
head_embedding,
tail_embedding,
head,
tail,
mask,
n_epochs,
n_vertices,
epochs_per_sample,
a,
b,
rng_state,
gamma=1.0,
initial_alpha=1.0,
negative_sample_rate=5.0,
parallel=False,
verbose=False,
densmap=False,
densmap_kwds={},
):
"""Improve an embedding using stochastic gradient descent to minimize the
fuzzy set cross entropy between the 1-skeletons of the high dimensional
and low dimensional fuzzy simplicial sets. In practice this is done by
sampling edges based on their membership strength (with the (1-p) terms
coming from negative sampling similar to word2vec).
Parameters
----------
head_embedding: array of shape (n_samples, n_components)
The initial embedding to be improved by SGD.
tail_embedding: array of shape (source_samples, n_components)
The reference embedding of embedded points. If not embedding new
previously unseen points with respect to an existing embedding this
is simply the head_embedding (again); otherwise it provides the
existing embedding to embed with respect to.
head: array of shape (n_1_simplices)
The indices of the heads of 1-simplices with non-zero membership.
tail: array of shape (n_1_simplices)
The indices of the tails of 1-simplices with non-zero membership.
    mask: array of shape (n_samples)
        The weight (in [0, 1]) assigned to each sample, defining how much it
        is updated: 0 means the point does not move at all, 1 means it is
        updated normally, and in-between values allow for fine-tuning.
n_epochs: int
The number of training epochs to use in optimization.
n_vertices: int
The number of vertices (0-simplices) in the dataset.
    epochs_per_sample: array of shape (n_1_simplices)
A float value of the number of epochs per 1-simplex. 1-simplices with
weaker membership strength will have more epochs between being sampled.
a: float
Parameter of differentiable approximation of right adjoint functor
b: float
Parameter of differentiable approximation of right adjoint functor
rng_state: array of int64, shape (3,)
The internal state of the rng
gamma: float (optional, default 1.0)
Weight to apply to negative samples.
initial_alpha: float (optional, default 1.0)
Initial learning rate for the SGD.
negative_sample_rate: int (optional, default 5)
Number of negative samples to use per positive sample.
parallel: bool (optional, default False)
Whether to run the computation using numba parallel.
Running in parallel is non-deterministic, and is not used
if a random seed has been set, to ensure reproducibility.
verbose: bool (optional, default False)
Whether to report information on the current progress of the algorithm.
densmap: bool (optional, default False)
Whether to use the density-augmented densMAP objective
densmap_kwds: dict (optional, default {})
Auxiliary data for densMAP
Returns
-------
embedding: array of shape (n_samples, n_components)
The optimized embedding.
"""

dim = head_embedding.shape[1]
move_other = head_embedding.shape[0] == tail_embedding.shape[0]
alpha = initial_alpha

epochs_per_negative_sample = epochs_per_sample / negative_sample_rate
epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
epoch_of_next_sample = epochs_per_sample.copy()

optimize_fn = numba.njit(
_optimize_layout_euclidean_masked_single_epoch, fastmath=True, parallel=parallel
)

if densmap:
dens_init_fn = numba.njit(
_optimize_layout_euclidean_densmap_epoch_init,
fastmath=True,
parallel=parallel,
)

dens_mu_tot = np.sum(densmap_kwds["mu_sum"]) / 2
dens_lambda = densmap_kwds["lambda"]
dens_R = densmap_kwds["R"]
dens_mu = densmap_kwds["mu"]
dens_phi_sum = np.zeros(n_vertices, dtype=np.float32)
dens_re_sum = np.zeros(n_vertices, dtype=np.float32)
dens_var_shift = densmap_kwds["var_shift"]
else:
dens_mu_tot = 0
dens_lambda = 0
dens_R = np.zeros(1, dtype=np.float32)
dens_mu = np.zeros(1, dtype=np.float32)
dens_phi_sum = np.zeros(1, dtype=np.float32)
dens_re_sum = np.zeros(1, dtype=np.float32)

for n in range(n_epochs):

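        # Apply the densMAP term only during the final "frac" fraction of epochs.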
densmap_flag = (
densmap
and (densmap_kwds["lambda"] > 0)
and (((n + 1) / float(n_epochs)) > (1 - densmap_kwds["frac"]))
)

if densmap_flag:
dens_init_fn(
head_embedding,
tail_embedding,
head,
tail,
a,
b,
dens_re_sum,
dens_phi_sum,
)

dens_re_std = np.sqrt(np.var(dens_re_sum) + dens_var_shift)
dens_re_mean = np.mean(dens_re_sum)
dens_re_cov = np.dot(dens_re_sum, dens_R) / (n_vertices - 1)
else:
dens_re_std = 0
dens_re_mean = 0
dens_re_cov = 0

optimize_fn(
head_embedding,
tail_embedding,
head,
tail,
mask,
n_vertices,
epochs_per_sample,
a,
b,
rng_state,
gamma,
dim,
move_other,
alpha,
epochs_per_negative_sample,
epoch_of_next_negative_sample,
epoch_of_next_sample,
n,
densmap_flag,
dens_phi_sum,
dens_re_sum,
dens_re_cov,
dens_re_std,
dens_re_mean,
dens_lambda,
dens_R,
dens_mu,
dens_mu_tot,
)

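        # Linearly decay the learning rate to zero over the epochs.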
alpha = initial_alpha * (1.0 - (float(n) / float(n_epochs)))

        if verbose and n % max(1, int(n_epochs / 10)) == 0:
print("\tcompleted ", n, " / ", n_epochs, "epochs")

return head_embedding
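
# Illustrative note: to pin an existing reference embedding while letting newly
# added points move freely, one could pass, e.g.,
#   mask = np.concatenate([np.zeros(n_ref), np.ones(n_new)])
# (hypothetical n_ref/n_new counts); values strictly between 0 and 1 damp,
# rather than forbid, a point's movement.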


@numba.njit(fastmath=True)
def optimize_layout_generic(
head_embedding,
Expand Down
8 changes: 7 additions & 1 deletion umap/parametric_umap.py
@@ -358,7 +358,7 @@ def _compile_model(self):
run_eagerly=self.run_eagerly,
)

def _fit_embed_data(self, X, n_epochs, init, random_state):
def _fit_embed_data(self, X, n_epochs, init, random_state, pin_mask):

if self.metric == "precomputed":
X = self._X
@@ -371,6 +371,12 @@ def _fit_embed_data(self, X, n_epochs, init, random_state):
if len(self.dims) > 1:
X = np.reshape(X, [len(X)] + list(self.dims))

        if pin_mask is not None:
            warn(
                "Pinning is not yet supported by Parametric UMAP. "
                "Ignoring the pin_mask."
            )

if self.parametric_reconstruction and (np.max(X) > 1.0 or np.min(X) < 0.0):
warn(
"Data should be scaled to the range 0-1 for cross-entropy reconstruction loss."
