diff --git a/requirements.txt b/requirements.txt index eb327de..6503090 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ ConfigArgParse pandas seaborn tqdm -scvi==0.6.7 +scvi-tools leidenalg torch scanpy diff --git a/setup.py b/setup.py index ce7b8cd..7c9aaf6 100644 --- a/setup.py +++ b/setup.py @@ -3,38 +3,36 @@ from setuptools import setup, find_packages if sys.version_info < (3,): - sys.exit('solo requires Python >= 3.6') + sys.exit("solo requires Python >= 3.6") try: from solo import __author__, __email__ except ImportError: # Deps not yet installed - __author__ = __email__ = '' + __author__ = __email__ = "" setup( - name='solo-sc', - version='0.6', - description='Neural network classifiers for doublets', - long_description=Path('README.md').read_text('utf-8'), + name="solo-sc", + version="1.0", + description="Neural network classifiers for doublets", + long_description=Path("README.md").read_text("utf-8"), long_description_content_type="text/markdown", - url='http://github.com/calico/solo', - download_url='https://github.com/calico/solo/archive/0.1.tar.gz', + url="http://github.com/calico/solo", + download_url="https://github.com/calico/solo/archive/1.0.tar.gz", author=__author__, author_email=__email__, - license='Apache', - python_requires='>=3.6', + license="Apache", + python_requires=">=3.6", install_requires=[ - l.strip() for l in - Path('requirements.txt').read_text('utf-8').splitlines() - ], + l.strip() for l in Path("requirements.txt").read_text("utf-8").splitlines() + ], packages=find_packages(), entry_points=dict( - console_scripts=['solo=solo.solo:main', - 'hashsolo=solo.hashsolo:main'], + console_scripts=["solo=solo.solo:main", "hashsolo=solo.hashsolo:main"], ), classifiers=[ - 'Environment :: Console', - 'Intended Audience :: Science/Research', - 'Topic :: Scientific/Engineering :: Bio-Informatics', + "Environment :: Console", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Bio-Informatics", ], ) diff --git a/solo/__init__.py b/solo/__init__.py index 8ebe692..e53900e 100644 --- a/solo/__init__.py +++ b/solo/__init__.py @@ -1,5 +1,5 @@ -__author__ = 'David Kelley, Nick Bernstein' -__email__ = 'nicholas@calicolabs.com' -__version__ = '0.1' +__author__ = "David Kelley, Nick Bernstein" +__email__ = "nicholas@calicolabs.com" +__version__ = "0.1" from . import hashsolo, utils diff --git a/solo/hashsolo.py b/solo/hashsolo.py index 3db1617..0a3f506 100644 --- a/solo/hashsolo.py +++ b/solo/hashsolo.py @@ -13,7 +13,7 @@ from scipy.sparse import issparse from sklearn.metrics import calinski_harabasz_score -''' +""" HashSolo script provides a probabilistic cell hashing demultiplexing method which generates a noise distribution and signal distribution for each hashing barcode from empirically observed counts. These distributions @@ -28,12 +28,11 @@ second highest barcode from a noise distribution. A negative two highest barcodes should come from noise distributions. We test each of these hypotheses in a bayesian fashion, and select the most probable hypothesis. -''' +""" -def _calculate_log_likelihoods(data, - number_of_noise_barcodes): - '''Calculate log likelihoods for each hypothesis, negative, singlet, doublet +def _calculate_log_likelihoods(data, number_of_noise_barcodes): + """Calculate log likelihoods for each hypothesis, negative, singlet, doublet Parameters ---------- @@ -47,9 +46,10 @@ def _calculate_log_likelihoods(data, a 2d np.array log likelihood of each hypothesis all_indices counter_to_barcode_combo - ''' + """ + def gaussian_updates(data, mu_o, std_o): - '''Update parameters of your gaussian + """Update parameters of your gaussian https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf Parameters ---------- @@ -65,13 +65,15 @@ def gaussian_updates(data, mu_o, std_o): mean of gaussian float std of gaussian - ''' - lam_o = 1/(std_o**2) + """ + lam_o = 1 / (std_o ** 2) n = len(data) - lam = 1/np.var(data) if len(data) > 1 else lam_o - lam_n = lam_o + n*lam - mu_n = (np.mean(data)*n*lam + mu_o*lam_o)/lam_n if len(data) > 0 else mu_o - return mu_n, (1 / (lam_n / (n + 1)))**(1/2) + lam = 1 / np.var(data) if len(data) > 1 else lam_o + lam_n = lam_o + n * lam + mu_n = ( + (np.mean(data) * n * lam + mu_o * lam_o) / lam_n if len(data) > 0 else mu_o + ) + return mu_n, (1 / (lam_n / (n + 1))) ** (1 / 2) eps = 1e-15 # probabilites for negative, singlet, doublets @@ -79,7 +81,11 @@ def gaussian_updates(data, mu_o, std_o): all_indices = np.empty(data.shape[0]) num_of_barcodes = data.shape[1] - number_of_non_noise_barcodes = num_of_barcodes - number_of_noise_barcodes if number_of_noise_barcodes is not None else 2 + number_of_non_noise_barcodes = ( + num_of_barcodes - number_of_noise_barcodes + if number_of_noise_barcodes is not None + else 2 + ) num_of_noise_barcodes = num_of_barcodes - number_of_non_noise_barcodes # assume log normal @@ -92,10 +98,12 @@ def gaussian_updates(data, mu_o, std_o): # barcodes with rank < k are considered to be noise global_signal_counts = np.ravel(data_sort[:, -1]) global_noise_counts = np.ravel(data_sort[:, :-number_of_non_noise_barcodes]) - global_mu_signal_o, global_sigma_signal_o = np.mean( - global_signal_counts), np.std(global_signal_counts) - global_mu_noise_o, global_sigma_noise_o = np.mean( - global_noise_counts), np.std(global_noise_counts) + global_mu_signal_o, global_sigma_signal_o = np.mean(global_signal_counts), np.std( + global_signal_counts + ) + global_mu_noise_o, global_sigma_noise_o = np.mean(global_noise_counts), np.std( + global_noise_counts + ) noise_params_dict = {} signal_params_dict = {} @@ -103,7 +111,9 @@ def gaussian_updates(data, mu_o, std_o): # for each barcode get empirical noise and signal distribution parameterization for x in np.arange(num_of_barcodes): sample_barcodes = data[:, x] - sample_barcodes_noise_idx = np.where(data_arg[:, :num_of_noise_barcodes] == x)[0] + sample_barcodes_noise_idx = np.where(data_arg[:, :num_of_noise_barcodes] == x)[ + 0 + ] sample_barcodes_signal_idx = np.where(data_arg[:, -1] == x) # get noise and signal counts @@ -111,8 +121,12 @@ def gaussian_updates(data, mu_o, std_o): signal_counts = sample_barcodes[sample_barcodes_signal_idx] # get parameters of distribution, assuming lognormal do update from global values - noise_param = gaussian_updates(noise_counts, global_mu_noise_o, global_sigma_noise_o) - signal_param = gaussian_updates(signal_counts, global_mu_signal_o, global_sigma_signal_o) + noise_param = gaussian_updates( + noise_counts, global_mu_noise_o, global_sigma_noise_o + ) + signal_param = gaussian_updates( + signal_counts, global_mu_signal_o, global_sigma_signal_o + ) noise_params_dict[x] = noise_param signal_params_dict[x] = signal_param @@ -120,17 +134,17 @@ def gaussian_updates(data, mu_o, std_o): counter = 0 # for each combination of noise and signal barcode calculate probiltiy of in silico and real cell hypotheses - for noise_sample_idx, signal_sample_idx in product(np.arange(num_of_barcodes), - np.arange(num_of_barcodes)): - signal_subset = (data_arg[:, -1] == signal_sample_idx) - noise_subset = (data_arg[:, -2] == noise_sample_idx) - subset = (signal_subset & noise_subset) + for noise_sample_idx, signal_sample_idx in product( + np.arange(num_of_barcodes), np.arange(num_of_barcodes) + ): + signal_subset = data_arg[:, -1] == signal_sample_idx + noise_subset = data_arg[:, -2] == noise_sample_idx + subset = signal_subset & noise_subset if sum(subset) == 0: continue indices = np.where(subset)[0] - barcode_combo = '_'.join([str(noise_sample_idx), - str(signal_sample_idx)]) + barcode_combo = "_".join([str(noise_sample_idx), str(signal_sample_idx)]) all_indices[np.where(subset)[0]] = counter counter_to_barcode_combo[counter] = barcode_combo counter += 1 @@ -139,20 +153,54 @@ def gaussian_updates(data, mu_o, std_o): # calculate probabilties for each hypothesis for each cell data_subset = data[subset] - log_signal_signal_probs = np.log(norm.pdf( - data_subset[:, signal_sample_idx], *signal_params[:-2], loc=signal_params[-2], scale=signal_params[-1]) + eps) + log_signal_signal_probs = np.log( + norm.pdf( + data_subset[:, signal_sample_idx], + *signal_params[:-2], + loc=signal_params[-2], + scale=signal_params[-1] + ) + + eps + ) signal_noise_params = signal_params_dict[noise_sample_idx] - log_noise_signal_probs = np.log(norm.pdf( - data_subset[:, noise_sample_idx], *signal_noise_params[:-2], loc=signal_noise_params[-2], scale=signal_noise_params[-1]) + eps) - - log_noise_noise_probs = np.log(norm.pdf( - data_subset[:, noise_sample_idx], *noise_params[:-2], loc=noise_params[-2], scale=noise_params[-1]) + eps) - log_signal_noise_probs = np.log(norm.pdf( - data_subset[:, signal_sample_idx], *noise_params[:-2], loc=noise_params[-2], scale=noise_params[-1]) + eps) - - probs_of_negative = np.sum([log_noise_noise_probs, log_signal_noise_probs], axis=0) - probs_of_singlet = np.sum([log_noise_noise_probs, log_signal_signal_probs], axis=0) - probs_of_doublet = np.sum([log_noise_signal_probs, log_signal_signal_probs], axis=0) + log_noise_signal_probs = np.log( + norm.pdf( + data_subset[:, noise_sample_idx], + *signal_noise_params[:-2], + loc=signal_noise_params[-2], + scale=signal_noise_params[-1] + ) + + eps + ) + + log_noise_noise_probs = np.log( + norm.pdf( + data_subset[:, noise_sample_idx], + *noise_params[:-2], + loc=noise_params[-2], + scale=noise_params[-1] + ) + + eps + ) + log_signal_noise_probs = np.log( + norm.pdf( + data_subset[:, signal_sample_idx], + *noise_params[:-2], + loc=noise_params[-2], + scale=noise_params[-1] + ) + + eps + ) + + probs_of_negative = np.sum( + [log_noise_noise_probs, log_signal_noise_probs], axis=0 + ) + probs_of_singlet = np.sum( + [log_noise_noise_probs, log_signal_signal_probs], axis=0 + ) + probs_of_doublet = np.sum( + [log_noise_signal_probs, log_signal_signal_probs], axis=0 + ) log_probs_list = [probs_of_negative, probs_of_singlet, probs_of_doublet] # each cell and each hypothesis probability @@ -162,7 +210,7 @@ def gaussian_updates(data, mu_o, std_o): def _calculate_bayes_rule(data, priors, number_of_noise_barcodes): - ''' + """ Calculate bayes rule from log likelihoods Parameters @@ -185,20 +233,28 @@ def _calculate_bayes_rule(data, priors, number_of_noise_barcodes): 'most_likely_hypothesis' key is a 1d np.array of the most likely hypothesis 'probs_hypotheses' key is a 2d np.array probability of each hypothesis 'log_likelihoods_for_each_hypothesis' key is a 2d np.array log likelihood of each hypothesis - ''' + """ priors = np.array(priors) - log_likelihoods_for_each_hypothesis, _, _ = _calculate_log_likelihoods(data, number_of_noise_barcodes) - probs_hypotheses = np.exp(log_likelihoods_for_each_hypothesis) * priors / np.sum( - np.multiply(np.exp(log_likelihoods_for_each_hypothesis), priors), axis=1)[:, None] + log_likelihoods_for_each_hypothesis, _, _ = _calculate_log_likelihoods( + data, number_of_noise_barcodes + ) + probs_hypotheses = ( + np.exp(log_likelihoods_for_each_hypothesis) + * priors + / np.sum( + np.multiply(np.exp(log_likelihoods_for_each_hypothesis), priors), axis=1 + )[:, None] + ) most_likely_hypothesis = np.argmax(probs_hypotheses, axis=1) - return {'most_likely_hypothesis': most_likely_hypothesis, - 'probs_hypotheses': probs_hypotheses, - 'log_likelihoods_for_each_hypothesis': log_likelihoods_for_each_hypothesis} + return { + "most_likely_hypothesis": most_likely_hypothesis, + "probs_hypotheses": probs_hypotheses, + "log_likelihoods_for_each_hypothesis": log_likelihoods_for_each_hypothesis, + } -def _get_clusters(clustering_data: anndata.AnnData, - resolutions: list): - ''' +def _get_clusters(clustering_data: anndata.AnnData, resolutions: list): + """ Principled cell clustering Parameters ---------- @@ -210,13 +266,15 @@ def _get_clusters(clustering_data: anndata.AnnData, ------- np.ndarray leiden clustering results for each cell - ''' + """ sc.pp.normalize_per_cell(clustering_data, counts_per_cell_after=1e4) sc.pp.log1p(clustering_data) - sc.pp.highly_variable_genes(clustering_data, min_mean=0.0125, max_mean=3, min_disp=0.5) - clustering_data = clustering_data[:, clustering_data.var['highly_variable']] + sc.pp.highly_variable_genes( + clustering_data, min_mean=0.0125, max_mean=3, min_disp=0.5 + ) + clustering_data = clustering_data[:, clustering_data.var["highly_variable"]] sc.pp.scale(clustering_data, max_value=10) - sc.tl.pca(clustering_data, svd_solver='arpack') + sc.tl.pca(clustering_data, svd_solver="arpack") sc.pp.neighbors(clustering_data, n_neighbors=10, n_pcs=40) sc.tl.umap(clustering_data) best_ch_score = -np.inf @@ -224,23 +282,26 @@ def _get_clusters(clustering_data: anndata.AnnData, for resolution in resolutions: sc.tl.leiden(clustering_data, resolution=resolution) - ch_score = calinski_harabasz_score(clustering_data.X, clustering_data.obs['leiden']) + ch_score = calinski_harabasz_score( + clustering_data.X, clustering_data.obs["leiden"] + ) if ch_score > best_ch_score: - clustering_data.obs['best_leiden'] = clustering_data.obs['leiden'].values + clustering_data.obs["best_leiden"] = clustering_data.obs["leiden"].values best_ch_score = ch_score - return clustering_data.obs['best_leiden'].values + return clustering_data.obs["best_leiden"].values -def hashsolo(cell_hashing_adata: anndata.AnnData, - priors: list = [.01, .8, .19], - pre_existing_clusters: str = None, - clustering_data: anndata.AnnData = None, - resolutions: list = [.1, .25, .5, .75, 1], - number_of_noise_barcodes: int = None, - inplace: bool = True, - ): - '''Demultiplex cell hashing dataset using HashSolo method +def hashsolo( + cell_hashing_adata: anndata.AnnData, + priors: list = [0.01, 0.8, 0.19], + pre_existing_clusters: str = None, + clustering_data: anndata.AnnData = None, + resolutions: list = [0.1, 0.25, 0.5, 0.75, 1], + number_of_noise_barcodes: int = None, + inplace: bool = True, +): + """Demultiplex cell hashing dataset using HashSolo method Parameters ---------- @@ -268,77 +329,117 @@ def hashsolo(cell_hashing_adata: anndata.AnnData, cell_hashing_adata : AnnData if inplace is False returns AnnData with demultiplexing results in .obs attribute otherwise does is in place - ''' + """ if issparse(cell_hashing_adata.X): cell_hashing_adata.X = np.array(cell_hashing_adata.X.todense()) - + if clustering_data is not None: - print('This may take awhile we are running clustering at {} different resolutions'.format(len(resolutions))) + print( + "This may take awhile we are running clustering at {} different resolutions".format( + len(resolutions) + ) + ) if not all(clustering_data.obs_names == cell_hashing_adata.obs_names): raise ValueError( - 'clustering_data and cell hashing cell_hashing_adata must have same index') - cell_hashing_adata.obs['best_leiden'] = _get_clusters(clustering_data, resolutions) + "clustering_data and cell hashing cell_hashing_adata must have same index" + ) + cell_hashing_adata.obs["best_leiden"] = _get_clusters( + clustering_data, resolutions + ) data = cell_hashing_adata.X num_of_cells = cell_hashing_adata.shape[0] - results = pd.DataFrame(np.zeros((num_of_cells, 6)), - columns=['most_likely_hypothesis', - 'probs_hypotheses', - 'cluster_feature', - 'negative_hypothesis_probability', - 'singlet_hypothesis_probability', - 'doublet_hypothesis_probability', ], - index=cell_hashing_adata.obs_names) + results = pd.DataFrame( + np.zeros((num_of_cells, 6)), + columns=[ + "most_likely_hypothesis", + "probs_hypotheses", + "cluster_feature", + "negative_hypothesis_probability", + "singlet_hypothesis_probability", + "doublet_hypothesis_probability", + ], + index=cell_hashing_adata.obs_names, + ) if clustering_data is not None or pre_existing_clusters is not None: - cluster_features = 'best_leiden' if pre_existing_clusters is None else pre_existing_clusters + cluster_features = ( + "best_leiden" if pre_existing_clusters is None else pre_existing_clusters + ) unique_cluster_features = np.unique(cell_hashing_adata.obs[cluster_features]) for cluster_feature in unique_cluster_features: - cluster_feature_bool_vector = cell_hashing_adata.obs[cluster_features] == cluster_feature - posterior_dict = _calculate_bayes_rule(data[cluster_feature_bool_vector], priors, number_of_noise_barcodes) - results.loc[cluster_feature_bool_vector, - 'most_likely_hypothesis'] = posterior_dict['most_likely_hypothesis'] - results.loc[cluster_feature_bool_vector, 'cluster_feature'] = cluster_feature - results.loc[cluster_feature_bool_vector, - 'negative_hypothesis_probability'] = posterior_dict['probs_hypotheses'][:, 0] - results.loc[cluster_feature_bool_vector, - 'singlet_hypothesis_probability'] = posterior_dict['probs_hypotheses'][:, 1] - results.loc[cluster_feature_bool_vector, - 'doublet_hypothesis_probability'] = posterior_dict['probs_hypotheses'][:, 2] + cluster_feature_bool_vector = ( + cell_hashing_adata.obs[cluster_features] == cluster_feature + ) + posterior_dict = _calculate_bayes_rule( + data[cluster_feature_bool_vector], priors, number_of_noise_barcodes + ) + results.loc[ + cluster_feature_bool_vector, "most_likely_hypothesis" + ] = posterior_dict["most_likely_hypothesis"] + results.loc[ + cluster_feature_bool_vector, "cluster_feature" + ] = cluster_feature + results.loc[ + cluster_feature_bool_vector, "negative_hypothesis_probability" + ] = posterior_dict["probs_hypotheses"][:, 0] + results.loc[ + cluster_feature_bool_vector, "singlet_hypothesis_probability" + ] = posterior_dict["probs_hypotheses"][:, 1] + results.loc[ + cluster_feature_bool_vector, "doublet_hypothesis_probability" + ] = posterior_dict["probs_hypotheses"][:, 2] else: posterior_dict = _calculate_bayes_rule(data, priors, number_of_noise_barcodes) - results.loc[:, 'most_likely_hypothesis'] = posterior_dict['most_likely_hypothesis'] - results.loc[:, 'cluster_feature'] = 0 - results.loc[:, 'negative_hypothesis_probability'] = posterior_dict['probs_hypotheses'][:, 0] - results.loc[:, 'singlet_hypothesis_probability'] = posterior_dict['probs_hypotheses'][:, 1] - results.loc[:, 'doublet_hypothesis_probability'] = posterior_dict['probs_hypotheses'][:, 2] - - cell_hashing_adata.obs['most_likely_hypothesis'] = results.loc[cell_hashing_adata.obs_names, - 'most_likely_hypothesis'] - cell_hashing_adata.obs['cluster_feature'] = results.loc[cell_hashing_adata.obs_names, 'cluster_feature'] - cell_hashing_adata.obs['negative_hypothesis_probability'] = results.loc[cell_hashing_adata.obs_names, - 'negative_hypothesis_probability'] - cell_hashing_adata.obs['singlet_hypothesis_probability'] = results.loc[cell_hashing_adata.obs_names, - 'singlet_hypothesis_probability'] - cell_hashing_adata.obs['doublet_hypothesis_probability'] = results.loc[cell_hashing_adata.obs_names, - 'doublet_hypothesis_probability'] - - cell_hashing_adata.obs['Classification'] = None - cell_hashing_adata.obs.loc[cell_hashing_adata.obs['most_likely_hypothesis'] - == 2, 'Classification'] = 'Doublet' - cell_hashing_adata.obs.loc[cell_hashing_adata.obs['most_likely_hypothesis'] - == 0, 'Classification'] = 'Negative' - all_sings = cell_hashing_adata.obs['most_likely_hypothesis'] == 1 + results.loc[:, "most_likely_hypothesis"] = posterior_dict[ + "most_likely_hypothesis" + ] + results.loc[:, "cluster_feature"] = 0 + results.loc[:, "negative_hypothesis_probability"] = posterior_dict[ + "probs_hypotheses" + ][:, 0] + results.loc[:, "singlet_hypothesis_probability"] = posterior_dict[ + "probs_hypotheses" + ][:, 1] + results.loc[:, "doublet_hypothesis_probability"] = posterior_dict[ + "probs_hypotheses" + ][:, 2] + + cell_hashing_adata.obs["most_likely_hypothesis"] = results.loc[ + cell_hashing_adata.obs_names, "most_likely_hypothesis" + ] + cell_hashing_adata.obs["cluster_feature"] = results.loc[ + cell_hashing_adata.obs_names, "cluster_feature" + ] + cell_hashing_adata.obs["negative_hypothesis_probability"] = results.loc[ + cell_hashing_adata.obs_names, "negative_hypothesis_probability" + ] + cell_hashing_adata.obs["singlet_hypothesis_probability"] = results.loc[ + cell_hashing_adata.obs_names, "singlet_hypothesis_probability" + ] + cell_hashing_adata.obs["doublet_hypothesis_probability"] = results.loc[ + cell_hashing_adata.obs_names, "doublet_hypothesis_probability" + ] + + cell_hashing_adata.obs["Classification"] = None + cell_hashing_adata.obs.loc[ + cell_hashing_adata.obs["most_likely_hypothesis"] == 2, "Classification" + ] = "Doublet" + cell_hashing_adata.obs.loc[ + cell_hashing_adata.obs["most_likely_hypothesis"] == 0, "Classification" + ] = "Negative" + all_sings = cell_hashing_adata.obs["most_likely_hypothesis"] == 1 singlet_sample_index = np.argmax(cell_hashing_adata.X[all_sings], axis=1) - cell_hashing_adata.obs.loc[all_sings, - 'Classification'] = cell_hashing_adata.var_names[singlet_sample_index] + cell_hashing_adata.obs.loc[ + all_sings, "Classification" + ] = cell_hashing_adata.var_names[singlet_sample_index] return cell_hashing_adata if not inplace else None -def plot_qc_checks_cell_hashing(cell_hashing_adata: anndata.AnnData, - alpha: float = .05, - fig_path: str = None): - '''Plot HashSolo demultiplexing results +def plot_qc_checks_cell_hashing( + cell_hashing_adata: anndata.AnnData, alpha: float = 0.05, fig_path: str = None +): + """Plot HashSolo demultiplexing results Parameters ---------- @@ -350,76 +451,119 @@ def plot_qc_checks_cell_hashing(cell_hashing_adata: anndata.AnnData, Path to save figure Returns ------- - ''' + """ import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot as plt cell_hashing_demultiplexing = cell_hashing_adata.obs - cell_hashing_demultiplexing['log_counts'] = np.log(np.sum(cell_hashing_adata.X, axis=1)) - number_of_clusters = cell_hashing_demultiplexing['cluster_feature'].drop_duplicates().shape[0] - fig, all_axes = plt.subplots(number_of_clusters, 4, figsize=(40, 10 * number_of_clusters)) + cell_hashing_demultiplexing["log_counts"] = np.log( + np.sum(cell_hashing_adata.X, axis=1) + ) + number_of_clusters = ( + cell_hashing_demultiplexing["cluster_feature"].drop_duplicates().shape[0] + ) + fig, all_axes = plt.subplots( + number_of_clusters, 4, figsize=(40, 10 * number_of_clusters) + ) counter = 0 - for cluster_feature, group in cell_hashing_demultiplexing.groupby('cluster_feature'): + for cluster_feature, group in cell_hashing_demultiplexing.groupby( + "cluster_feature" + ): if number_of_clusters > 1: axes = all_axes[counter] else: axes = all_axes ax = axes[0] - ax.plot(group['log_counts'], group['negative_hypothesis_probability'], 'bo', alpha=alpha) - ax.set_title('Probability of negative hypothesis vs log hashing counts') - ax.set_ylabel('Probability of negative hypothesis') - ax.set_xlabel('Log hashing counts') + ax.plot( + group["log_counts"], + group["negative_hypothesis_probability"], + "bo", + alpha=alpha, + ) + ax.set_title("Probability of negative hypothesis vs log hashing counts") + ax.set_ylabel("Probability of negative hypothesis") + ax.set_xlabel("Log hashing counts") ax = axes[1] - ax.plot(group['log_counts'], group['singlet_hypothesis_probability'], 'bo', alpha=alpha) - ax.set_title('Probability of singlet hypothesis vs log hashing counts') - ax.set_ylabel('Probability of singlet hypothesis') - ax.set_xlabel('Log hashing counts') + ax.plot( + group["log_counts"], + group["singlet_hypothesis_probability"], + "bo", + alpha=alpha, + ) + ax.set_title("Probability of singlet hypothesis vs log hashing counts") + ax.set_ylabel("Probability of singlet hypothesis") + ax.set_xlabel("Log hashing counts") ax = axes[2] - ax.plot(group['log_counts'], group['doublet_hypothesis_probability'], 'bo', alpha=alpha) - ax.set_title('Probability of doublet hypothesis vs log hashing counts') - ax.set_ylabel('Probability of doublet hypothesis') - ax.set_xlabel('Log hashing counts') + ax.plot( + group["log_counts"], + group["doublet_hypothesis_probability"], + "bo", + alpha=alpha, + ) + ax.set_title("Probability of doublet hypothesis vs log hashing counts") + ax.set_ylabel("Probability of doublet hypothesis") + ax.set_xlabel("Log hashing counts") ax = axes[3] - group['Classification'].value_counts().plot.bar(ax=ax) - ax.set_title('Count of each samples classification') + group["Classification"].value_counts().plot.bar(ax=ax) + ax.set_title("Count of each samples classification") counter += 1 plt.show() if fig_path is not None: - fig.savefig(fig_path, dpi=300, format='pdf') + fig.savefig(fig_path, dpi=300, format="pdf") def main(): - usage = 'hashsolo' + usage = "hashsolo" parser = ArgumentParser(usage, formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument(dest='data_file', - help='h5ad file containing cell hashing counts') - parser.add_argument('-j', dest='model_json_file', - default=None, - help='json file to pass optional arguments') - parser.add_argument('-o', dest='out_dir', - default='hashsolo_output', - help='Output directory for results') - parser.add_argument('-c', dest='clustering_data', - default=None, - help='h5ad file with count transcriptional data to\ - perform clustering on') - parser.add_argument('-p', dest='pre_existing_clusters', - default=None, - help='column in cell_hashing_data_file.obs to \ - specifying different cell types or clusters') - parser.add_argument('-q', dest='plot_name', - default='hashing_qc_plots.pdf', - help='name of plot to output') - parser.add_argument('-n', dest='number_of_noise_barcodes', - default=None, - help='Number of barcodes to use to create noise \ - distribution') + parser.add_argument( + dest="data_file", help="h5ad file containing cell hashing counts" + ) + parser.add_argument( + "-j", + dest="model_json_file", + default=None, + help="json file to pass optional arguments", + ) + parser.add_argument( + "-o", + dest="out_dir", + default="hashsolo_output", + help="Output directory for results", + ) + parser.add_argument( + "-c", + dest="clustering_data", + default=None, + help="h5ad file with count transcriptional data to\ + perform clustering on", + ) + parser.add_argument( + "-p", + dest="pre_existing_clusters", + default=None, + help="column in cell_hashing_data_file.obs to \ + specifying different cell types or clusters", + ) + parser.add_argument( + "-q", + dest="plot_name", + default="hashing_qc_plots.pdf", + help="name of plot to output", + ) + parser.add_argument( + "-n", + dest="number_of_noise_barcodes", + default=None, + help="Number of barcodes to use to create noise \ + distribution", + ) args = parser.parse_args() @@ -432,37 +576,41 @@ def main(): params = {} data_file = args.data_file data_ext = os.path.splitext(data_file)[-1] - if data_ext == '.h5ad': + if data_ext == ".h5ad": cell_hashing_adata = anndata.read(data_file) else: - print('Unrecognized file format') + print("Unrecognized file format") if args.clustering_data is not None: clustering_data_file = args.clustering_data clustering_data_ext = os.path.splitext(clustering_data_file)[-1] - if clustering_data_ext == '.h5ad': + if clustering_data_ext == ".h5ad": clustering_data = anndata.read(clustering_data_file) else: - print('Unrecognized file format for clustering data') + print("Unrecognized file format for clustering data") else: clustering_data = None if not os.path.isdir(args.out_dir): os.mkdir(args.out_dir) - hashsolo(cell_hashing_adata, - pre_existing_clusters=args.pre_existing_clusters, - number_of_noise_barcodes=args.number_of_noise_barcodes, - clustering_data=clustering_data, - **params) - cell_hashing_adata.write(os.path.join(args.out_dir, 'hashsoloed.h5ad')) + hashsolo( + cell_hashing_adata, + pre_existing_clusters=args.pre_existing_clusters, + number_of_noise_barcodes=args.number_of_noise_barcodes, + clustering_data=clustering_data, + **params + ) + cell_hashing_adata.write(os.path.join(args.out_dir, "hashsoloed.h5ad")) plot_qc_checks_cell_hashing( - cell_hashing_adata, fig_path=os.path.join(args.out_dir, args.plot_name)) + cell_hashing_adata, fig_path=os.path.join(args.out_dir, args.plot_name) + ) + ############################################################################### # __main__ ############################################################################### -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/solo/solo.py b/solo/solo.py index 7cb50d0..c39678f 100755 --- a/solo/solo.py +++ b/solo/solo.py @@ -1,31 +1,29 @@ #!/usr/bin/env python -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import json import os -import anndata +import umap +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import numpy as np -from sklearn.metrics import roc_auc_score, roc_curve -from scipy.sparse import issparse -from collections import defaultdict +from sklearn.metrics import * +from scipy.special import softmax +from scanpy import read_10x_mtx -import scvi -from scvi.dataset import AnnDatasetFromAnnData, LoomDataset, \ - GeneExpressionDataset, Dataset10X -from scvi.models import Classifier, VAE -from scvi.inference import UnsupervisedTrainer, ClassifierTrainer import torch -import umap +from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from .utils import create_average_doublet, create_summed_doublet, \ - create_multinomial_doublet, make_gene_expression_dataset, \ - knn_smooth_pred_class +import scvi +from scvi.data import read_h5ad, read_loom, setup_anndata +from scvi.model import SCVI +from scvi.external import SOLO + +from .utils import knn_smooth_pred_class -''' +""" solo.py Simulate doublets, train a VAE, and then a classifier on top. -''' +""" ############################################################################### @@ -34,507 +32,435 @@ def main(): - usage = 'solo' + usage = "solo" parser = ArgumentParser(usage, formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument(dest='model_json_file', - help='json file to pass VAE parameters') - parser.add_argument(dest='data_path', - help='path to h5ad, loom or 10x directory containing cell by genes counts') - parser.add_argument('--set-reproducible-seed', dest='reproducible_seed', - default=None, type=int, - help='Reproducible seed, give an int to set seed') - parser.add_argument('-d', dest='doublet_depth', - default=2., type=float, - help='Depth multiplier for a doublet relative to the \ - average of its constituents') - parser.add_argument('-g', dest='gpu', - default=True, action='store_true', - help='Run on GPU') - parser.add_argument('-a', dest='anndata_output', - default=False, action='store_true', - help='output modified anndata object with solo scores \ - Only works for anndata') - parser.add_argument('-o', dest='out_dir', - default='solo_out') - parser.add_argument('-r', dest='doublet_ratio', - default=2., type=float, - help='Ratio of doublets to true \ - cells') - parser.add_argument('-s', dest='seed', - default=None, help='Path to previous solo output \ + parser.add_argument(dest="model_json_file", help="json file to pass VAE parameters") + parser.add_argument( + dest="data_path", help="path to h5ad, loom, or 10x mtx dir cell by genes counts" + ) + parser.add_argument( + "--set-reproducible-seed", + dest="reproducible_seed", + default=None, + type=int, + help="Reproducible seed, give an int to set seed", + ) + parser.add_argument( + "-d", + dest="doublet_depth", + default=2.0, + type=float, + help="Depth multiplier for a doublet relative to the \ + average of its constituents", + ) + parser.add_argument( + "-g", dest="gpu", default=True, action="store_true", help="Run on GPU" + ) + parser.add_argument( + "-a", + dest="anndata_output", + default=False, + action="store_true", + help="output modified anndata object with solo scores \ + Only works for anndata", + ) + parser.add_argument("-o", dest="out_dir", default="solo_out") + parser.add_argument( + "-r", + dest="doublet_ratio", + default=2, + type=int, + help="Ratio of doublets to true \ + cells", + ) + parser.add_argument( + "-s", + dest="seed", + default=None, + help="Path to previous solo output \ directory. Seed VAE models with previously \ trained solo model. Directory structure is assumed to \ be the same as solo output directory structure. \ should at least have a vae.pt a pickled object of \ vae weights and a latent.npy an np.ndarray of the \ - latents of your cells.') - parser.add_argument('-k', dest='known_doublets', - help='Experimentally defined doublets tsv file. \ - Should be a single column of True/False. True \ - indicates the cell is a doublet. No header.', - type=str) - parser.add_argument('-t', dest='doublet_type', help='Please enter \ - multinomial, average, or sum', - default='multinomial', - choices=['multinomial', 'average', 'sum']) - parser.add_argument('-e', dest='expected_number_of_doublets', - help='Experimentally expected number of doublets', - type=int, default=None) - parser.add_argument('-p', dest='plot', - default=False, action='store_true', - help='Plot outputs for solo') - parser.add_argument('-l', dest='normal_logging', - default=False, action='store_true', - help='Logging level set to normal (aka not debug)') - parser.add_argument('--random_size', dest='randomize_doublet_size', - default=False, - action='store_true', - help='Sample depth multipliers from Unif(1, \ - DoubletDepth) \ - to provide a diversity of possible doublet depths.' - ) + latents of your cells.", + ) + parser.add_argument( + "-e", + dest="expected_number_of_doublets", + help="Experimentally expected number of doublets", + type=int, + default=None, + ) + parser.add_argument( + "-p", + dest="plot", + default=False, + action="store_true", + help="Plot outputs for solo", + ) + parser.add_argument( + "-recalibrate_scores", + dest="recalibrate_scores", + default=False, + action="store_true", + help="Recalibrate doublet scores", + ) args = parser.parse_args() - if not args.normal_logging: - scvi._settings.set_verbosity(10) - model_json_file = args.model_json_file data_path = args.data_path if args.gpu and not torch.cuda.is_available(): args.gpu = torch.cuda.is_available() - print('Cuda is not available, switching to cpu running!') + print("Cuda is not available, switching to cpu running!") if not os.path.isdir(args.out_dir): os.mkdir(args.out_dir) if args.reproducible_seed is not None: - torch.manual_seed(args.reproducible_seed) - np.random.seed(args.reproducible_seed) + scvi.settings.seed = args.reproducible_seed + else: + scvi.settings.seed = np.random.randint(10000) + ################################################## # data # read loom/anndata data_ext = os.path.splitext(data_path)[-1] - if data_ext == '.loom': - scvi_data = LoomDataset(data_path) - elif data_ext == '.h5ad': - adata = anndata.read(data_path) - if issparse(adata.X): - adata.X = adata.X.todense() - scvi_data = AnnDatasetFromAnnData(adata) + if data_ext == ".loom": + scvi_data = read_loom(data_path) + elif data_ext == ".h5ad": + scvi_data = read_h5ad(data_path) elif os.path.isdir(data_path): - scvi_data = Dataset10X(save_path=data_path, - measurement_names_column=1, - dense=True) + scvi_data = read_10x_mtx(path=data_path) cell_umi_depth = scvi_data.X.sum(axis=1) fifth, ninetyfifth = np.percentile(cell_umi_depth, [5, 95]) min_cell_umi_depth = np.min(cell_umi_depth) max_cell_umi_depth = np.max(cell_umi_depth) if fifth * 10 < ninetyfifth: - print("""WARNING YOUR DATA HAS A WIDE RANGE OF CELL DEPTHS. - PLEASE MANUALLY REVIEW YOUR DATA""") - print(f"Min cell depth: {min_cell_umi_depth}, Max cell depth: {max_cell_umi_depth}") + print( + """WARNING YOUR DATA HAS A WIDE RANGE OF CELL DEPTHS. + PLEASE MANUALLY REVIEW YOUR DATA""" + ) + print( + f"Min cell depth: {min_cell_umi_depth}, Max cell depth: {max_cell_umi_depth}" + ) else: - msg = f'{data_path} is not a recognized format.\n' - msg += 'must be one of {h5ad, loom, 10x directory}' + msg = f"{data_path} is not a recognized format.\n" + msg += "must be one of {h5ad, loom, 10x mtx dir}" raise TypeError(msg) num_cells, num_genes = scvi_data.X.shape - if args.known_doublets is not None: - print('Removing known doublets for in silico doublet generation') - print('Make sure known doublets are in the same order as your data') - known_doublets = np.loadtxt(args.known_doublets, dtype=str) == 'True' - - assert len(known_doublets) == scvi_data.X.shape[0] - known_doublet_data = make_gene_expression_dataset( - scvi_data.X[known_doublets], - scvi_data.gene_names) - known_doublet_data.labels = np.ones(known_doublet_data.X.shape[0]) - singlet_scvi_data = make_gene_expression_dataset( - scvi_data.X[~known_doublets], - scvi_data.gene_names) - singlet_num_cells, _ = singlet_scvi_data.X.shape - else: - known_doublet_data = None - singlet_num_cells = num_cells - known_doublets = np.zeros(num_cells, dtype=bool) - singlet_scvi_data = scvi_data - singlet_scvi_data.labels = np.zeros(singlet_scvi_data.X.shape[0]) - scvi_data.labels = known_doublets.astype(int) - ################################################## - # parameters - # check for parameters if not os.path.exists(model_json_file): - raise FileNotFoundError(f'{model_json_file} does not exist.') + raise FileNotFoundError(f"{model_json_file} does not exist.") # read parameters - with open(model_json_file, 'r') as model_json_open: + with open(model_json_file, "r") as model_json_open: params = json.load(model_json_open) # set VAE params vae_params = {} - for par in ['n_hidden', 'n_latent', 'n_layers', 'dropout_rate', - 'ignore_batch']: + for par in ["n_hidden", "n_latent", "n_layers", "dropout_rate", "ignore_batch"]: if par in params: vae_params[par] = params[par] - vae_params['n_batch'] = 0 if params.get( - 'ignore_batch', False) else scvi_data.n_batches # training parameters - batch_size = params.get('batch_size', 128) - valid_pct = params.get('valid_pct', 0.1) - learning_rate = params.get('learning_rate', 1e-3) - stopping_params = {'patience': params.get('patience', 10), 'threshold': 0} + batch_key = params.get("batch_key", None) + batch_size = params.get("batch_size", 128) + valid_pct = params.get("valid_pct", 0.1) + learning_rate = params.get("learning_rate", 1e-3) + stopping_params = {"patience": params.get("patience", 40), "min_delta": 0} # protect against single example batch while num_cells % batch_size == 1: - batch_size = int(np.round(1.25*batch_size)) - print('Increasing batch_size to %d to avoid single example batch.' % batch_size) + batch_size = int(np.round(1.25 * batch_size)) + print("Increasing batch_size to %d to avoid single example batch." % batch_size) + scvi.settings.batch_size = batch_size ################################################## - # VAE - - vae = VAE(n_input=singlet_scvi_data.nb_genes, n_labels=2, - reconstruction_loss='nb', - log_variational=True, **vae_params) + # SCVI + setup_anndata(scvi_data, batch_key=batch_key) + vae = SCVI( + scvi_data, + gene_likelihood="nb", + log_variational=True, + **vae_params, + use_observed_lib_size=False, + ) if args.seed: - if args.gpu: - device = torch.device('cuda') - vae.load_state_dict(torch.load(os.path.join(args.seed, 'vae.pt'))) - vae.to(device) - else: - map_loc = 'cpu' - vae.load_state_dict(torch.load(os.path.join(args.seed, 'vae.pt'), - map_location=map_loc)) - - # save latent representation - utrainer = \ - UnsupervisedTrainer(vae, singlet_scvi_data, - train_size=(1. - valid_pct), - frequency=2, - metrics_to_monitor=['reconstruction_error'], - use_cuda=args.gpu, - early_stopping_kwargs=stopping_params, - batch_size=batch_size) - - full_posterior = utrainer.create_posterior( - utrainer.model, - singlet_scvi_data, - indices=np.arange(len(singlet_scvi_data))) - latent, _, _ = full_posterior.sequential(batch_size).get_latent() - np.save(os.path.join(args.out_dir, 'latent.npy'), - latent.astype('float32')) - + vae = vae.load(os.path.join(args.seed, "vae"), use_gpu=args.gpu) else: - stopping_params['early_stopping_metric'] = 'reconstruction_error' - stopping_params['save_best_state_metric'] = 'reconstruction_error' - - # initialize unsupervised trainer - utrainer = \ - UnsupervisedTrainer(vae, singlet_scvi_data, - train_size=(1. - valid_pct), - frequency=2, - metrics_to_monitor=['reconstruction_error'], - use_cuda=args.gpu, - early_stopping_kwargs=stopping_params, - batch_size=batch_size) - utrainer.history['reconstruction_error_test_set'].append(0) - # initial epoch - utrainer.train(n_epochs=2000, lr=learning_rate) - - # drop learning rate and continue - utrainer.early_stopping.wait = 0 - utrainer.train(n_epochs=500, lr=0.5 * learning_rate) - + scvi_callbacks = [] + scvi_callbacks += [ + EarlyStopping( + monitor="reconstruction_loss_validation", mode="min", **stopping_params + ) + ] + plan_kwargs = { + "reduce_lr_on_plateau": True, + "lr_factor": 0.1, + "lr": 1e-2, + "lr_patience": 10, + "lr_threshold": 0, + "lr_min": 1e-4, + "lr_scheduler_metric": "reconstruction_loss_validation", + } + + vae.train( + max_epochs=2000, + validation_size=valid_pct, + check_val_every_n_epoch=1, + plan_kwargs=plan_kwargs, + callbacks=scvi_callbacks, + ) # save VAE - torch.save(vae.state_dict(), os.path.join(args.out_dir, 'vae.pt')) + vae.save(os.path.join(args.out_dir, "vae")) - # save latent representation - full_posterior = utrainer.create_posterior( - utrainer.model, - singlet_scvi_data, - indices=np.arange(len(singlet_scvi_data))) - latent, _, _ = full_posterior.sequential(batch_size).get_latent() - np.save(os.path.join(args.out_dir, 'latent.npy'), - latent.astype('float32')) - - ################################################## - # simulate doublets - - non_zero_indexes = np.where(singlet_scvi_data.X > 0) - cells = non_zero_indexes[0] - genes = non_zero_indexes[1] - cells_ids = defaultdict(list) - for cell_id, gene in zip(cells, genes): - cells_ids[cell_id].append(gene) - - # choose doublets function type - if args.doublet_type == 'average': - doublet_function = create_average_doublet - elif args.doublet_type == 'sum': - doublet_function = create_summed_doublet - else: - doublet_function = create_multinomial_doublet - - cell_depths = singlet_scvi_data.X.sum(axis=1) - num_doublets = int(args.doublet_ratio * singlet_num_cells) - if known_doublet_data is not None: - num_doublets -= known_doublet_data.X.shape[0] - # make sure we are making a non negative amount of doublets - assert num_doublets >= 0 - - in_silico_doublets = np.zeros((num_doublets, num_genes), dtype='float32') - # for desired # doublets - for di in range(num_doublets): - # sample two cells - i, j = np.random.choice(singlet_num_cells, size=2) - - # generate doublets - in_silico_doublets[di, :] = \ - doublet_function(singlet_scvi_data.X, i, j, - doublet_depth=args.doublet_depth, - cell_depths=cell_depths, cells_ids=cells_ids, - randomize_doublet_size=args.randomize_doublet_size) - - # merge datasets - # we can maybe up sample the known doublets - # concatentate - classifier_data = GeneExpressionDataset() - classifier_data.populate_from_data( - X=np.vstack([scvi_data.X, - in_silico_doublets]), - labels=np.hstack([np.ravel(scvi_data.labels), - np.ones(in_silico_doublets.shape[0])]), - remap_attributes=False) - - assert(len(np.unique(classifier_data.labels.flatten())) == 2) + latent = vae.get_latent_representation() + # save latent representation + np.save(os.path.join(args.out_dir, "latent.npy"), latent.astype("float32")) ################################################## # classifier # model - classifier = Classifier(n_input=(vae.n_latent + 1), - n_hidden=params['cl_hidden'], - n_layers=params['cl_layers'], n_labels=2, - dropout_rate=params['dropout_rate']) - - # trainer - stopping_params['early_stopping_metric'] = 'accuracy' - stopping_params['save_best_state_metric'] = 'accuracy' - strainer = ClassifierTrainer(classifier, classifier_data, - train_size=(1. - valid_pct), - frequency=2, metrics_to_monitor=['accuracy'], - use_cuda=args.gpu, - sampling_model=vae, sampling_zl=True, - early_stopping_kwargs=stopping_params, - batch_size=batch_size) - - # initial - strainer.train(n_epochs=1000, lr=learning_rate) - - # drop learning rate and continue - strainer.early_stopping.wait = 0 - strainer.train(n_epochs=300, lr=0.1 * learning_rate) - torch.save(classifier.state_dict(), os.path.join(args.out_dir, 'classifier.pt')) - - - ################################################## - # post-processing - # use logits for predictions for better results - logits_classifier = Classifier(n_input=(vae.n_latent + 1), - n_hidden=params['cl_hidden'], - n_layers=params['cl_layers'], n_labels=2, - dropout_rate=params['dropout_rate'], - logits=True) - logits_classifier.load_state_dict(classifier.state_dict()) - - # using logits leads to better performance in for ranking - logits_strainer = ClassifierTrainer(logits_classifier, classifier_data, - train_size=(1. - valid_pct), - frequency=2, - metrics_to_monitor=['accuracy'], - use_cuda=args.gpu, - sampling_model=vae, sampling_zl=True, - early_stopping_kwargs=stopping_params, - batch_size=batch_size) - - # models evaluation mode - vae.eval() - classifier.eval() - logits_classifier.eval() - - print('Train accuracy: %.4f' % strainer.train_set.accuracy()) - print('Test accuracy: %.4f' % strainer.test_set.accuracy()) - - # compute predictions manually - # output logits - train_y, train_score = strainer.train_set.compute_predictions(soft=True) - test_y, test_score = strainer.test_set.compute_predictions(soft=True) - # train_y == true label - # train_score[:, 0] == singlet score; train_score[:, 1] == doublet score - train_score = train_score[:, 1] - train_y = train_y.astype('bool') - test_score = test_score[:, 1] - test_y = test_y.astype('bool') - - train_auroc = roc_auc_score(train_y, train_score) - test_auroc = roc_auc_score(test_y, test_score) - - print('Train AUROC: %.4f' % train_auroc) - print('Test AUROC: %.4f' % test_auroc) - - train_fpr, train_tpr, train_t = roc_curve(train_y, train_score) - test_fpr, test_tpr, test_t = roc_curve(test_y, test_score) - train_t = np.minimum(train_t, 1 + 1e-9) - test_t = np.minimum(test_t, 1 + 1e-9) - - train_acc = np.zeros(len(train_t)) - for i in range(len(train_t)): - train_acc[i] = np.mean(train_y == (train_score > train_t[i])) - test_acc = np.zeros(len(test_t)) - for i in range(len(test_t)): - test_acc[i] = np.mean(test_y == (test_score > test_t[i])) + # todo add doublet ratio + solo = SOLO.from_scvi_model(vae, doublet_ratio=args.doublet_ratio) + solo.train( + 2000, + lr=learning_rate, + train_size=0.9, + check_val_every_n_epoch=1, + early_stopping_patience=30, + ) + solo.train( + 2000, + lr=learning_rate * 0.1, + train_size=0.9, + check_val_every_n_epoch=1, + early_stopping_patience=30, + callbacks=[], + ) + solo.save(os.path.join(args.out_dir, "classifier")) + + logit_predictions = solo.predict(include_simulated_doublets=True) + + is_doublet_known = solo.adata.obs._solo_doub_sim == "doublet" + is_doublet_pred = logit_predictions.idxmin(axis=1) == "singlet" + + validation_is_doublet_known = is_doublet_known[solo.validation_indices] + validation_is_doublet_pred = is_doublet_pred[solo.validation_indices] + training_is_doublet_known = is_doublet_known[solo.train_indices] + training_is_doublet_pred = is_doublet_pred[solo.train_indices] + + valid_as = accuracy_score(validation_is_doublet_known, validation_is_doublet_pred) + valid_roc = roc_auc_score(validation_is_doublet_known, validation_is_doublet_pred) + valid_ap = average_precision_score( + validation_is_doublet_known, validation_is_doublet_pred + ) + + train_as = accuracy_score(training_is_doublet_known, training_is_doublet_pred) + train_roc = roc_auc_score(training_is_doublet_known, training_is_doublet_pred) + train_ap = average_precision_score( + training_is_doublet_known, training_is_doublet_pred + ) + + print(f"Training results") + print(f"AUROC: {train_roc}, Accuracy: {train_as}, Average precision: {train_ap}") + + print(f"Validation results") + print(f"AUROC: {valid_roc}, Accuracy: {valid_as}, Average precision: {valid_ap}") # write predictions # softmax predictions - order_y, order_score = strainer.compute_predictions(soft=True) - _, order_pred = strainer.compute_predictions() - doublet_score = order_score[:, 1] - np.save(os.path.join(args.out_dir, 'no_updates_softmax_scores.npy'), doublet_score[:num_cells]) - np.savetxt(os.path.join(args.out_dir, 'no_updates_softmax_scores.csv'), doublet_score[:num_cells], delimiter=",") - - np.save(os.path.join(args.out_dir, 'no_updates_softmax_scores_sim.npy'), doublet_score[num_cells:]) + softmax_predictions = softmax(logit_predictions, axis=1) + doublet_score = softmax_predictions.loc[:, "doublet"] + + np.save( + os.path.join(args.out_dir, "no_updates_softmax_scores.npy"), + doublet_score[:num_cells], + ) + np.savetxt( + os.path.join(args.out_dir, "no_updates_softmax_scores.csv"), + doublet_score[:num_cells], + delimiter=",", + ) + np.save( + os.path.join(args.out_dir, "no_updates_softmax_scores_sim.npy"), + doublet_score[num_cells:], + ) # logit predictions - logit_y, logit_score = logits_strainer.compute_predictions(soft=True) - logit_doublet_score = logit_score[:, 1] - np.save(os.path.join(args.out_dir, 'logit_scores.npy'), logit_doublet_score[:num_cells]) - np.savetxt(os.path.join(args.out_dir, 'logit_scores.csv'), logit_doublet_score[:num_cells], delimiter=",") - - np.save(os.path.join(args.out_dir, 'logit_scores_sim.npy'), logit_doublet_score[num_cells:]) - + logit_doublet_score = logit_predictions.loc[:, "doublet"] + np.save( + os.path.join(args.out_dir, "logit_scores.npy"), logit_doublet_score[:num_cells] + ) + np.savetxt( + os.path.join(args.out_dir, "logit_scores.csv"), + logit_doublet_score[:num_cells], + delimiter=",", + ) + np.save( + os.path.join(args.out_dir, "logit_scores_sim.npy"), + logit_doublet_score[num_cells:], + ) # update threshold as a function of Solo's estimate of the number of # doublets # essentially a log odds update # TODO put in a function + # currently overshrinking softmaxes diff = np.inf counter_update = 0 solo_scores = doublet_score[:num_cells] logit_scores = logit_doublet_score[:num_cells] - d_s = (args.doublet_ratio / (args.doublet_ratio + 1)) - while (diff > .01) | (counter_update < 5): + d_s = args.doublet_ratio / (args.doublet_ratio + 1) + if args.recalibrate_scores: + while (diff > 0.01) | (counter_update < 5): - # calculate log odds calibration for logits - d_o = np.mean(solo_scores) - c = np.log(d_o/(1-d_o)) - np.log(d_s/(1-d_s)) + # calculate log odds calibration for logits + d_o = np.mean(solo_scores) + c = np.log(d_o / (1 - d_o)) - np.log(d_s / (1 - d_s)) - # update solo scores - solo_scores = 1 / (1+np.exp(-(logit_scores + c))) + # update solo scores + solo_scores = 1 / (1 + np.exp(-(logit_scores + c))) - # update while conditions - diff = np.abs(d_o - np.mean(solo_scores)) - counter_update += 1 - - np.save(os.path.join(args.out_dir, 'softmax_scores.npy'), - solo_scores) - np.savetxt(os.path.join(args.out_dir, 'softmax_scores.csv'), - solo_scores, delimiter=",") + # update while conditions + diff = np.abs(d_o - np.mean(solo_scores)) + counter_update += 1 + np.save(os.path.join(args.out_dir, "softmax_scores.npy"), solo_scores) + np.savetxt( + os.path.join(args.out_dir, "softmax_scores.csv"), solo_scores, delimiter="," + ) if args.expected_number_of_doublets is not None: k = len(solo_scores) - args.expected_number_of_doublets - if args.expected_number_of_doublets / len(solo_scores) > .5: - print('''Make sure you actually expect more than half your cells + if args.expected_number_of_doublets / len(solo_scores) > 0.5: + print( + """Make sure you actually expect more than half your cells to be doublets. If not change your - -e parameter value''') + -e parameter value""" + ) assert k > 0 idx = np.argpartition(solo_scores, k) threshold = np.max(solo_scores[idx[:k]]) is_solo_doublet = solo_scores > threshold else: - is_solo_doublet = solo_scores > .5 - - is_doublet = known_doublets - new_doublets_idx = np.where(~(is_doublet) & is_solo_doublet[:num_cells])[0] - is_doublet[new_doublets_idx] = True - - np.save(os.path.join(args.out_dir, 'is_doublet.npy'), is_doublet[:num_cells]) - np.savetxt(os.path.join(args.out_dir, 'is_doublet.csv'), is_doublet[:num_cells], delimiter=",") - - np.save(os.path.join(args.out_dir, 'is_doublet_sim.npy'), is_doublet[num_cells:]) - - np.save(os.path.join(args.out_dir, 'preds.npy'), order_pred[:num_cells]) - np.savetxt(os.path.join(args.out_dir, 'preds.csv'), order_pred[:num_cells], delimiter=",") - - np.save(os.path.join(args.out_dir, 'preds_sim.npy'), order_pred[num_cells:]) - - smoothed_preds = knn_smooth_pred_class(X=latent, pred_class=is_doublet[:num_cells]) - np.save(os.path.join(args.out_dir, 'smoothed_preds.npy'), smoothed_preds) - - if args.anndata_output and data_ext == '.h5ad': - adata.obs['is_doublet'] = is_doublet[:num_cells] - adata.obs['logit_scores'] = logit_doublet_score[:num_cells] - adata.obs['softmax_scores'] = doublet_score[:num_cells] - adata.write(os.path.join(args.out_dir, "soloed.h5ad")) + is_solo_doublet = solo_scores > 0.5 + + np.save(os.path.join(args.out_dir, "is_doublet.npy"), is_solo_doublet[:num_cells]) + np.savetxt( + os.path.join(args.out_dir, "is_doublet.csv"), + is_solo_doublet[:num_cells], + delimiter=",", + ) + + np.save( + os.path.join(args.out_dir, "is_doublet_sim.npy"), is_solo_doublet[num_cells:] + ) + + np.save(os.path.join(args.out_dir, "preds.npy"), is_doublet_pred[:num_cells]) + np.savetxt( + os.path.join(args.out_dir, "preds.csv"), + is_doublet_pred[:num_cells], + delimiter=",", + ) + + smoothed_preds = knn_smooth_pred_class( + X=latent, pred_class=is_doublet_pred[:num_cells] + ) + np.save(os.path.join(args.out_dir, "smoothed_preds.npy"), smoothed_preds) + + if args.anndata_output and data_ext == ".h5ad": + scvi_data.obs["is_doublet"] = is_solo_doublet[:num_cells].values.astype(bool) + scvi_data.obs["logit_scores"] = logit_doublet_score[:num_cells].values.astype( + float + ) + scvi_data.obs["softmax_scores"] = solo_scores[:num_cells].values.astype(float) + scvi_data.write(os.path.join(args.out_dir, "soloed.h5ad")) if args.plot: import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot as plt import seaborn as sns + + train_solo_scores = doublet_score[solo.train_indices] + validation_solo_scores = doublet_score[solo.validation_indices] + + train_fpr, train_tpr, _ = roc_curve( + training_is_doublet_known, train_solo_scores + ) + val_fpr, val_tpr, _ = roc_curve( + validation_is_doublet_known, validation_solo_scores + ) + # plot ROC plt.figure() - plt.plot(train_fpr, train_tpr, label='Train') - plt.plot(test_fpr, test_tpr, label='Test') - plt.gca().set_xlabel('False positive rate') - plt.gca().set_ylabel('True positive rate') + plt.plot(train_fpr, train_tpr, label="Train") + plt.plot(val_fpr, val_tpr, label="Validation") + plt.gca().set_xlabel("False positive rate") + plt.gca().set_ylabel("True positive rate") plt.legend() - plt.savefig(os.path.join(args.out_dir, 'roc.pdf')) + plt.savefig(os.path.join(args.out_dir, "roc.pdf")) plt.close() + train_precision, train_recall, _ = precision_recall_curve( + training_is_doublet_known, train_solo_scores + ) + val_precision, val_recall, _ = precision_recall_curve( + validation_is_doublet_known, validation_solo_scores + ) # plot accuracy plt.figure() - plt.plot(train_t, train_acc, label='Train') - plt.plot(test_t, test_acc, label='Test') - plt.axvline(0.5, color='black', linestyle='--') - plt.gca().set_xlabel('Threshold') - plt.gca().set_ylabel('Accuracy') + plt.plot(train_recall, train_precision, label="Train") + plt.plot(val_recall, val_precision, label="Validation") + plt.gca().set_xlabel("Recall") + plt.gca().set_ylabel("pytPrecision") plt.legend() - plt.savefig(os.path.join(args.out_dir, 'accuracy.pdf')) + plt.savefig(os.path.join(args.out_dir, "precision_recall.pdf")) plt.close() # plot distributions + obs_indices = solo.validation_indices[solo.validation_indices < num_cells] + sim_indices = solo.validation_indices[solo.validation_indices > num_cells] + plt.figure() - sns.distplot(test_score[test_y], label='Simulated') - sns.distplot(test_score[~test_y], label='Observed') + sns.displot(doublet_score[sim_indices], label="Simulated") + sns.displot(doublet_score[obs_indices], label="Observed") plt.legend() - plt.savefig(os.path.join(args.out_dir, 'train_v_test_dist.pdf')) + plt.savefig(os.path.join(args.out_dir, "sim_vs_obs_dist.pdf")) plt.close() plt.figure() - sns.distplot(doublet_score[:num_cells], label='Observed') + sns.distplot(solo_scores[:num_cells], label="Observed (transformed)") plt.legend() - plt.savefig(os.path.join(args.out_dir, 'real_cells_dist.pdf')) + plt.savefig(os.path.join(args.out_dir, "real_cells_dist.pdf")) plt.close() scvi_umap = umap.UMAP(n_neighbors=16).fit_transform(latent) fig, ax = plt.subplots(1, 1, figsize=(10, 10)) - ax.scatter(scvi_umap[:, 0], scvi_umap[:, 1], - c=doublet_score[:num_cells], s=8, cmap="GnBu") + ax.scatter( + scvi_umap[:, 0], + scvi_umap[:, 1], + c=doublet_score[:num_cells], + s=8, + cmap="GnBu", + ) ax.set_xlabel("UMAP 1") ax.set_ylabel("UMAP 2") - ax.set_xticks([], []) - ax.set_yticks([], []) - fig.savefig(os.path.join(args.out_dir, 'umap_solo_scores.pdf')) + fig.savefig(os.path.join(args.out_dir, "umap_solo_scores.pdf")) + ############################################################################### # __main__ ############################################################################### -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/solo/utils.py b/solo/utils.py index cb3c36e..9352561 100644 --- a/solo/utils.py +++ b/solo/utils.py @@ -1,15 +1,16 @@ import numpy as np -from scvi.dataset import GeneExpressionDataset from scipy.stats import multinomial from sklearn.neighbors import NearestNeighbors -def knn_smooth_pred_class(X: np.ndarray, - pred_class: np.ndarray, - grouping: np.ndarray = None, - k: int = 15,) -> np.ndarray: - ''' +def knn_smooth_pred_class( + X: np.ndarray, + pred_class: np.ndarray, + grouping: np.ndarray = None, + k: int = 15, +) -> np.ndarray: + """ Smooths class predictions by taking the modal class from each cell's nearest neighbors. Parameters @@ -40,7 +41,7 @@ def knn_smooth_pred_class(X: np.ndarray, By using a simple kNN smoothing heuristic, we can leverage neighborhood information to improve classification performance, smoothing out cells that have an outlier prediction relative to their local neighborhood. - ''' + """ if grouping is None: # do not use a grouping to restrict local neighborhood # associations, create a universal pseudogroup `0`. @@ -49,7 +50,7 @@ def knn_smooth_pred_class(X: np.ndarray, smooth_pred_class = np.zeros_like(pred_class) for group in np.unique(grouping): # identify only cells in the relevant group - group_idx = np.where(grouping == group)[0].astype('int') + group_idx = np.where(grouping == group)[0].astype("int") X_group = X[grouping == group, :] # if there are < k cells in the group, change `k` to the # group size @@ -58,7 +59,9 @@ def knn_smooth_pred_class(X: np.ndarray, else: k_use = k # compute a nearest neighbor graph and identify kNN - nns = NearestNeighbors(n_neighbors=k_use,).fit(X_group) + nns = NearestNeighbors( + n_neighbors=k_use, + ).fit(X_group) dist, idx = nns.kneighbors(X_group) # for each cell in the group, assign a class as @@ -69,118 +72,3 @@ def knn_smooth_pred_class(X: np.ndarray, maj_class = uniq_classes[int(np.argmax(counts))] smooth_pred_class[group_idx[i]] = maj_class return smooth_pred_class - - -def create_average_doublet(X: np.ndarray, - i: int, - j: int, **kwargs): - '''make an average combination of 2 cells - - Parameters - ---------- - X : np.array - cell by genes matrix - i : int, - randomly chosen ith cell - j : int, - randomly chosen jth cell - Returns - ------- - float64 - average expression vector of two cells - ''' - return (X[i, :] + X[j, :]).astype('float64') / 2 - - -def create_summed_doublet(X: np.ndarray, - i: int, - j: int, **kwargs): - '''make a sum combination of 2 cells - - Parameters - ---------- - X : np.array - cell by genes matrix - i : int, - randomly chosen ith cell - j : int, - randomly chosen jth cell - Returns - ------- - float64 - summed expression vector of two cells - ''' - return (X[i, :] + X[j, :]).astype('float64') - - -def create_multinomial_doublet(X: np.ndarray, - i: int, - j: int, **kwargs): - '''make a multinomial combination of 2 cells - - Parameters - ---------- - X : np.array - cell by genes matrix - i : int, - randomly chosen ith cell - j : int, - randomly chosen jth cell - kwargs : dict, - dict with doublet_depth, cell_depths and cells_ids as keys - doublet_depth is an int - cell_depths is an list of all cells total UMI counts as ints - cell_ids list of lists with genes with counts for each cell - Returns - ------- - float64 - multinomial expression vector of two cells - ''' - doublet_depth = kwargs["doublet_depth"] - cell_depths = kwargs["cell_depths"] - cells_ids = kwargs["cells_ids"] - randomize_doublet_size = kwargs["randomize_doublet_size"] - - # add their counts - dp = (X[i, :] - + X[j, :]).astype('float64') - dp = np.ravel(dp) - non_zero_indexes = np.unique(cells_ids[i] + cells_ids[j]) - # a huge hack caused by - # https://github.com/numpy/numpy/issues/8317 - # fun fun fun https://stackoverflow.com/questions/23257587/how-can-i-avoid-value-errors-when-using-numpy-random-multinomial - # okay with this hack because affects pro - dp = dp[non_zero_indexes] - # normalize - dp /= dp.sum() - if randomize_doublet_size: - scale_factor = np.random.uniform(1., doublet_depth) - else: - scale_factor = doublet_depth - # choose depth - dd = int(scale_factor * (cell_depths[i] + cell_depths[j]) / 2) - - # sample counts from multinomial - non_zero_probs = multinomial.rvs(n=dd, p=dp) - probs = np.zeros(X.shape[1]) - probs[non_zero_indexes] = non_zero_probs - return probs - - -def make_gene_expression_dataset(data: np.ndarray, gene_names: np.ndarray): - '''make an scVI GeneExpressionDataset - - Parameters - ---------- - data : np.array - cell by genes matrix - gene_names : np.array, - string array with gene names - Returns - ------- - ge_data : GeneExpressionDataset - scVI GeneExpressionDataset for scVI processing - ''' - ge_data = GeneExpressionDataset() - ge_data.populate_from_data(X=data, gene_names=gene_names) - return ge_data diff --git a/testdata/calculate_performance.py b/testdata/calculate_performance.py index 2373182..0bc90dc 100644 --- a/testdata/calculate_performance.py +++ b/testdata/calculate_performance.py @@ -1,26 +1,30 @@ #!/usr/bin/env python -import anndata +import anndata import numpy as np from sklearn.metrics import average_precision_score, roc_auc_score from scipy.stats import mannwhitneyu import matplotlib.pyplot as plt -import datetime +import datetime import pandas as pd from glob import glob -''' + +""" calculate performance -''' +""" ############################################################################### # main ############################################################################### -experiment_name_to_dataset = {'pbmc': '2c.h5ad', - 'kidney': 'gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad'} +experiment_name_to_dataset = { + "pbmc": "2c.h5ad", + "kidney": "gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad", +} + def main(): - for result in glob('results_*/softmax_scores.npy'): + for result in glob("results_*/softmax_scores.npy"): experiment_name = result.split("/")[0].split("_")[1] experiment_number = result.split("/")[0].split("_")[2] scores = np.load(result) @@ -29,43 +33,57 @@ def main(): apr = average_precision_score(true_labels, scores) auc = roc_auc_score(true_labels, scores) time = datetime.datetime.now().strftime("%Y-%m-%d %H") - with open('tracking_performance.csv', 'a') as file: - file.write(f'{time},{experiment_name},{experiment_number},{apr},{auc}\n') + with open("tracking_performance.csv", "a") as file: + file.write(f"{time},{experiment_name},{experiment_number},{apr},{auc}\n") - performance_tracking = pd.read_csv('tracking_performance.csv') - performance_tracking['date (dt)'] = pd.to_datetime(performance_tracking['date'], format="%Y-%m-%d %H") - for experiment_name, group in performance_tracking.groupby('experiment_name'): - fig, axes = plt.subplots(2, 1, figsize=(10,20)) + performance_tracking = pd.read_csv("tracking_performance.csv") + performance_tracking["date (dt)"] = pd.to_datetime( + performance_tracking["date"], format="%Y-%m-%d %H" + ) + for experiment_name, group in performance_tracking.groupby("experiment_name"): + fig, axes = plt.subplots(2, 1, figsize=(10, 20)) ax = axes[0] - ax.plot(group['date'], group['average_precision'], '.') - ax.set_xlabel('date') - ax.set_ylabel('average precision') + ax.plot(group["date"], group["average_precision"], ".") + ax.set_xlabel("date") + ax.set_ylabel("average precision") ax = axes[1] - ax.plot(group['date'], group['AUROC'], '.') - ax.set_xlabel('date') - ax.set_ylabel('AUROC') - fig.savefig(f'{experiment_name}_performance_tracking.png') - second_to_last, most_recent = group['date (dt)'].drop_duplicates().sort_values()[-2:] - second_to_last_df = group[group['date (dt)'] == second_to_last] - most_recent_df = group[group['date (dt)'] == most_recent] - for metric in ['AUROC', 'average_precision']: - mean_change = most_recent_df[metric].mean() - second_to_last_df[metric].mean() - pvalue = mannwhitneyu(most_recent_df[metric], second_to_last_df[metric]).pvalue - print(f'Mean {metric} has changed by for {experiment_name}: {mean_change}') - print(f'P value for metric change {metric} in experiment {experiment_name}: {pvalue}') - if mean_change < 0 and pvalue < .05: - for x in range(0,5): - print('WARNING!') - print(f'WARNING {metric} HAS GOTTEN SIGNIFICANTLY WORSE for {experiment_name}!') - if mean_change > 0 and pvalue < .05: - for x in range(0,5): - print('NICE JOB!') - print(f'NICE JOB {metric} HAS GOTTEN SIGNIFICANTLY BETTER for {experiment_name}!') + ax.plot(group["date"], group["AUROC"], ".") + ax.set_xlabel("date") + ax.set_ylabel("AUROC") + fig.savefig(f"{experiment_name}_performance_tracking.png") + second_to_last, most_recent = ( + group["date (dt)"].drop_duplicates().sort_values()[-2:] + ) + second_to_last_df = group[group["date (dt)"] == second_to_last] + most_recent_df = group[group["date (dt)"] == most_recent] + for metric in ["AUROC", "average_precision"]: + mean_change = ( + most_recent_df[metric].mean() - second_to_last_df[metric].mean() + ) + pvalue = mannwhitneyu( + most_recent_df[metric], second_to_last_df[metric] + ).pvalue + print(f"Mean {metric} has changed by for {experiment_name}: {mean_change}") + print( + f"P value for metric change {metric} in experiment {experiment_name}: {pvalue}" + ) + if mean_change < 0 and pvalue < 0.05: + for x in range(0, 5): + print("WARNING!") + print( + f"WARNING {metric} HAS GOTTEN SIGNIFICANTLY WORSE for {experiment_name}!" + ) + if mean_change > 0 and pvalue < 0.05: + for x in range(0, 5): + print("NICE JOB!") + print( + f"NICE JOB {metric} HAS GOTTEN SIGNIFICANTLY BETTER for {experiment_name}!" + ) + - ############################################################################### # __main__ ############################################################################### -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/testdata/kidney_performance_tracking.png b/testdata/kidney_performance_tracking.png index 1510371..ced32f3 100644 Binary files a/testdata/kidney_performance_tracking.png and b/testdata/kidney_performance_tracking.png differ diff --git a/testdata/pbmc_performance_tracking.png b/testdata/pbmc_performance_tracking.png index 50db371..4ad97d7 100644 Binary files a/testdata/pbmc_performance_tracking.png and b/testdata/pbmc_performance_tracking.png differ diff --git a/testdata/performance_test_kidney_PoolB4FACs_L4_Rep1.sh b/testdata/performance_test_kidney_PoolB4FACs_L4_Rep1.sh index 256d90f..1931b9b 100644 --- a/testdata/performance_test_kidney_PoolB4FACs_L4_Rep1.sh +++ b/testdata/performance_test_kidney_PoolB4FACs_L4_Rep1.sh @@ -13,5 +13,5 @@ echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID echo 'kidney' source activate solo-sc -solo -g -r 2 -d 2 -t sum -o results_kidney_"$SLURM_ARRAY_TASK_ID" ../solo_params_example.json gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad +solo -p -a -g -r 2 --set-reproducible-seed "$SLURM_ARRAY_TASK_ID" -o results_kidney_"$SLURM_ARRAY_TASK_ID" ../solo_params_example.json gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad diff --git a/testdata/performance_test_pbmc_2c.sh b/testdata/performance_test_pbmc_2c.sh index cefae1a..e077c6c 100644 --- a/testdata/performance_test_pbmc_2c.sh +++ b/testdata/performance_test_pbmc_2c.sh @@ -13,5 +13,5 @@ echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID echo 'pbmc' source activate solo-sc -solo -g -r 2 -d 2 -t sum -o results_pbmc_"$SLURM_ARRAY_TASK_ID" ../solo_params_example.json 2c.h5ad +solo -p -a -g -r 2 --set-reproducible-seed "$SLURM_ARRAY_TASK_ID" -o results_pbmc_"$SLURM_ARRAY_TASK_ID" ../solo_params_example.json 2c.h5ad \ No newline at end of file diff --git a/testdata/tracking_performance.csv b/testdata/tracking_performance.csv index f3c509d..3915908 100644 --- a/testdata/tracking_performance.csv +++ b/testdata/tracking_performance.csv @@ -13,3 +13,31 @@ date,experiment_name,experiment_number,average_precision,AUROC 2021-02-11 11,kidney,3,0.6467093639812479,0.7573174534495049 2021-02-11 11,kidney,4,0.6586798480352564,0.7664625647110882 2021-02-11 11,pbmc,5,0.6478040939493424,0.9225424292745276 +2021-02-18 14,kidney,1,0.6462699623991646,0.7627879613475528 +2021-02-18 14,pbmc,4,0.663326842485893,0.9165240840419175 +2021-02-18 14,kidney,5,0.6480422575749873,0.7652822022156713 +2021-02-18 14,kidney,2,0.648134236408012,0.7684026253229614 +2021-02-18 14,pbmc,3,0.6697257411365286,0.9161077482696734 +2021-02-18 14,pbmc,1,0.6660700438231723,0.916962444444143 +2021-02-18 14,pbmc,6,0.6515373213882809,0.9138401321031976 +2021-02-18 14,pbmc,2,0.6693514926787216,0.9162711036660809 +2021-02-18 14,kidney,3,0.6492802667687769,0.7649097761990968 +2021-02-18 14,kidney,4,0.6472856211876892,0.7633028207786451 +2021-02-18 14,pbmc,5,0.6710936786375143,0.9192144568023591 +2021-02-19 16,pbmc,4,0.6723676680997775,0.9177868779532147 +2021-02-19 16,pbmc,3,0.6651878244617202,0.9157017827302478 +2021-02-19 16,pbmc,1,0.6658168009292,0.9160364650338042 +2021-02-19 16,pbmc,2,0.6746171723485587,0.9178513284655913 +2021-02-19 16,pbmc,5,0.6647481223195565,0.918238273835011 +2021-06-08 11,kidney,6,0.6574942259697635,0.7713399289897841 +2021-06-08 11,kidney,1,0.6491524324360479,0.7683925479354703 +2021-06-08 11,pbmc,4,0.6376143679761396,0.9177092708136871 +2021-06-08 11,kidney,5,0.6507265141297041,0.764045303675037 +2021-06-08 11,kidney,2,0.6536805748853676,0.7694188176691817 +2021-06-08 11,pbmc,3,0.6497117076563463,0.9188563930159294 +2021-06-08 11,pbmc,1,0.6410024329832936,0.9169514442439101 +2021-06-08 11,pbmc,6,0.628524458448013,0.9143017043816851 +2021-06-08 11,pbmc,2,0.6408642380778574,0.9174513476166006 +2021-06-08 11,kidney,3,0.6573702076510592,0.7709389098122496 +2021-06-08 11,kidney,4,0.6527276636519359,0.7714110145008662 +2021-06-08 11,pbmc,5,0.6339190482879701,0.915807205354064 diff --git a/tests/hashsolo_tests.py b/tests/hashsolo_tests.py index 3bb64a3..f448767 100644 --- a/tests/hashsolo_tests.py +++ b/tests/hashsolo_tests.py @@ -9,6 +9,7 @@ def test_cell_demultiplexing(): from scipy import stats import random + random.seed(52) signal = stats.poisson.rvs(1000, 1, 990) doublet_signal = stats.poisson.rvs(1000, 1, 10) @@ -23,30 +24,26 @@ def test_cell_demultiplexing(): test_data = AnnData(x) hashsolo.hashsolo(test_data) - doublets = ['Doublet'] * 10 - classes = list(np.repeat(np.arange(10), 98).reshape(98, 10, - order='F').ravel()) - negatives = ['Negative'] * 10 + doublets = ["Doublet"] * 10 + classes = list(np.repeat(np.arange(10), 98).reshape(98, 10, order="F").ravel()) + negatives = ["Negative"] * 10 classification = doublets + classes + negatives - assert all(test_data.obs['Classification'] == classification) + assert all(test_data.obs["Classification"] == classification) doublets = [2] * 10 classes = [1] * 980 negatives = [0] * 10 classification = doublets + classes + negatives - ll_results = np.argmax(hashsolo._calculate_log_likelihoods(x, 8)[0], - axis=1) + ll_results = np.argmax(hashsolo._calculate_log_likelihoods(x, 8)[0], axis=1) assert all(ll_results == classification) - bayes_results = hashsolo._calculate_bayes_rule(x, [.1, .8, .1], 8) - assert all(bayes_results['most_likely_hypothesis'] == classification) - - singlet_prior = .99999999999999999 - other_prior = (1 - singlet_prior)/2 - bayes_results = hashsolo._calculate_bayes_rule(x, - [other_prior, - singlet_prior, - other_prior], 8) - assert all(bayes_results['most_likely_hypothesis'] == 1) + bayes_results = hashsolo._calculate_bayes_rule(x, [0.1, 0.8, 0.1], 8) + assert all(bayes_results["most_likely_hypothesis"] == classification) + singlet_prior = 0.99999999999999999 + other_prior = (1 - singlet_prior) / 2 + bayes_results = hashsolo._calculate_bayes_rule( + x, [other_prior, singlet_prior, other_prior], 8 + ) + assert all(bayes_results["most_likely_hypothesis"] == 1)