diff --git a/umap/sparse.py b/umap/sparse.py index dac4e0bb..2c4025f8 100644 --- a/umap/sparse.py +++ b/umap/sparse.py @@ -696,8 +696,8 @@ def sparse_correlation(ind1, data1, ind2, data2, n_features): for i in range(data2.shape[0]): shifted_data2[i] = data2[i] - mu_y - norm1 = norm(shifted_data1) - norm2 = norm(shifted_data2) + norm1 = np.sqrt(norm(shifted_data1) ** 2 + (n_features - ind1.shape[0]) * mu_x ** 2) + norm2 = np.sqrt(norm(shifted_data2) ** 2 + (n_features - ind2.shape[0]) * mu_y ** 2) dot_prod_inds, dot_prod_data = sparse_mul(ind1, shifted_data1, ind2, shifted_data2) @@ -705,9 +705,22 @@ def sparse_correlation(ind1, data1, ind2, data2, n_features): if dot_prod_data.shape[0] == 0: return 1.0 + common_indices = set(dot_prod_inds) + for i in range(dot_prod_data.shape[0]): dot_product += dot_prod_data[i] + for i in range(ind1.shape[0]): + if ind1[i] not in common_indices: + dot_product -= data1[i] * (mu_y) + + for i in range(ind2.shape[0]): + if ind2[i] not in common_indices: + dot_product -= data2[i] * (mu_x) + + all_indices = arr_union(ind1, ind2) + dot_product += mu_x * mu_y * all_indices.shape[0] + if dot_product == 0.0: return 1.0 else: