-
Notifications
You must be signed in to change notification settings - Fork 0
/
latent_space_plots_seurat.py
65 lines (50 loc) · 1.99 KB
/
latent_space_plots_seurat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import warnings
warnings.filterwarnings('ignore')
import scanpy as sc
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import glob, os
import matplotlib
import re
import beta_vae_5
# Module-level plot styling, applied once when this file is imported/run:
# seaborn theme at unit font scale with a dark grid background.
sns.set(font_scale=1)
sns.set_style("darkgrid")
def double_feature_to_latent(path, adata, feature, feature2, model, z_dim):
    """Encode cells into the model's latent space and plot per-dimension counts.

    The cells in ``adata.X`` are encoded with ``model.to_latent`` and the
    latent coordinates (plus the ``feature`` and ``feature2`` labels) are
    printed and written to ``cells_in_latent.csv``. Then, for each of the
    ``z_dim`` latent dimensions, the coordinates are rounded to one decimal
    place and the cells in every observed (bin, ``feature``, ``feature2``)
    group are counted. The counts are printed and drawn as a bubble plot
    (x = rounded coordinate, y = ``feature``, colour = ``feature2``,
    marker size = cell count), shown, and saved as ``<i>dim.png``. The file
    name uses the 0-based index ``i``; the plot title uses ``i + 1``.

    All outputs go to the directory ``path + "cells_latent_" + feature +
    feature2 + "/"``, which is created if it does not exist. ``path`` is
    concatenated as a plain string, so it should end with a path separator.
    The process working directory is not changed.

    Args:
        path: Output directory prefix (string-concatenated, see above).
        adata: Annotated data object. ``adata.X`` is passed to the model and
            ``adata.obs[feature]`` / ``adata.obs[feature2]`` supply the labels.
            Presumably an AnnData object (scanpy is imported) -- confirm.
        feature: Column of ``adata.obs`` plotted on the y axis.
        feature2: Column of ``adata.obs`` used for the marker colour.
        model: Object exposing ``to_latent``. Its output must have one row
            per cell and ``z_dim`` columns to fit the DataFrame built below.
        z_dim: Number of latent dimensions to export and plot.

    Raises:
        OSError: If the output directory cannot be created for a reason other
            than already existing (e.g. permissions).
    """
    latent = model.to_latent(adata.X)
    df_cols = [str(i) + "dim" for i in range(z_dim)]
    # Index by the `feature` labels so reset_index() turns them into a
    # regular column named after the obs column.
    latent_df = pd.DataFrame(latent, index=adata.obs[feature], columns=df_cols)
    latent_df.reset_index(level=0, inplace=True)
    latent_df[feature2] = list(adata.obs[feature2])
    print(latent_df)

    path = path + "cells_latent_" + feature + feature2 + "/"
    try:
        os.makedirs(path)
    except FileExistsError:
        print("Check if path %s already exists" % path)
    else:
        print("Successfully created the directory %s" % path)

    # Write to explicit paths instead of os.chdir(path): changing the
    # process-wide cwd made a later call with a relative `path` nest its
    # output inside this call's directory.
    latent_df.to_csv(os.path.join(path, "cells_in_latent.csv"))

    for i, dim_col in enumerate(df_cols):
        # Bin this dimension's coordinates to one decimal place.
        latent_df["groups_dim"] = latent_df[dim_col].round(1)
        # observed=True: with a categorical grouper (common for obs columns)
        # the default would add every unobserved combination with a count of
        # 0, which the scatter would still draw as a minimum-size marker.
        counts = latent_df.groupby(
            ["groups_dim", feature, feature2], observed=True
        ).count()
        counts = counts.reset_index(level=[0, 1, 2])
        # After count(), every remaining column holds the group size; keep the
        # current dimension's column as the count.
        counts = counts.loc[:, [dim_col, "groups_dim", feature, feature2]]
        print(counts)

        fig, ax = plt.subplots(figsize=(6, 6))
        # Pass x/y as keywords: positional data vectors were deprecated in
        # seaborn 0.11 and raise TypeError from 0.12 on.
        scatter = sns.scatterplot(
            x=counts["groups_dim"],
            y=counts[feature],
            size=counts[dim_col].values,
            hue=counts[feature2],
            linewidth=0,
            sizes=(10, 150),
            ax=ax,
        )
        scatter.set_title("Latent Space for Dimension " + str(i + 1), weight="bold")
        scatter.set_ylabel(feature.capitalize())
        scatter.set_xlabel("Linear scale")
        # Save before show(): once show() returns the figure may already be
        # closed, so saving afterwards can write a blank image.
        fig.savefig(os.path.join(path, dim_col + ".png"), bbox_inches="tight", dpi=100)
        plt.show()
        # Release the figure; otherwise one stays open per latent dimension.
        plt.close(fig)