Skip to content
This repository has been archived by the owner on Dec 12, 2023. It is now read-only.

Commit

Permalink
improve plots module
Browse files Browse the repository at this point in the history
  • Loading branch information
abearab committed Nov 12, 2023
1 parent cea8bdd commit bb9b1f1
Showing 1 changed file with 160 additions and 157 deletions.
317 changes: 160 additions & 157 deletions screenpro/plots.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# import
import pandas as pd
import numpy as np
from plotnine import * # <-- NOT recommended!
import matplotlib.pyplot as plt
import matplotlib
from phenoScore import ann_score_df

# variables
almost_black = '#111111'
Expand All @@ -22,161 +21,165 @@
'YlBu', [(0, '#0000ff'), (.49, '#000000'), (.51, '#000000'), (1, '#ffff00')])
yellow_blue.set_bad('#999999', 1)

# Global matplotlib defaults for every figure drawn after this module is
# imported. `axisLineWidth` stays a module-level name because the axes and
# tick widths below share it.
axisLineWidth = .5

plt.rcParams.update({
    # fonts / text
    'font.sans-serif': ['Helvetica', 'Arial', 'Verdana', 'Bitstream Vera Sans'],
    'font.size': 8,
    'font.weight': 'regular',
    'text.color': almost_black,

    # axes and lines ('axes.color_cycle' is deliberately left unset)
    'axes.linewidth': axisLineWidth,
    'lines.linewidth': 1.5,
    'axes.facecolor': 'white',
    'axes.edgecolor': almost_black,
    'axes.labelcolor': almost_black,

    # patches ('patch.facecolor' is deliberately left unset)
    'patch.edgecolor': 'none',
    'patch.linewidth': .25,

    # saved figures
    'savefig.dpi': 1000,
    'savefig.format': 'svg',

    # legend
    'legend.frameon': False,
    'legend.handletextpad': .25,
    'legend.fontsize': 8,
    'legend.numpoints': 1,
    'legend.scatterpoints': 1,

    # ticks
    'ytick.direction': 'out',
    'ytick.color': almost_black,
    'ytick.major.width': axisLineWidth,
    'xtick.direction': 'out',
    'xtick.color': almost_black,
    'xtick.major.width': axisLineWidth,
})


# Builds a plotnine scatter plot from `adata.obs` and stores it in `scatter_p`.
# The x/y/color/size/alpha/shape arguments are forwarded unchanged to aes()
# (presumably column names of `adata.obs` -- verify against callers), and
# `facet` is forwarded to facet_wrap().
# NOTE(review): the x-axis range is fixed to (-60, 80) by the xlim() call at
# the end of the expression, whatever variable is passed as `x` -- confirm
# this is intended for a general-purpose scatter helper.
def plot_ggplot_scatter(adata, x, y, color, size, alpha, shape, facet):
"""
Plot scatter plot using ggplot2 style via plotnine.
Args:
adata: anndata object
x: x-axis variable
y: y-axis variable
color: color variable
size: size variable
alpha: alpha variable
shape: shape variable
facet: facet variable
Returns:
plotnine ggplot object
"""
scatter_p = (
# only the per-observation table (`adata.obs`) is plotted
ggplot(adata.obs)
+ geom_point(
aes(
x = x,
y = y,
color = color,
size = size,
alpha = alpha,
shape = shape
),
stroke=0.2
)
# one panel per value of `facet`
+ facet_wrap(facet)
+ theme_classic()
# theme overrides: blank grid/background, legend on top, fixed font sizes
+ theme(
panel_grid_major = element_blank(),
panel_grid_minor = element_blank(),
panel_background = element_blank(),

legend_background = element_blank(),
legend_position = 'top',
legend_direction = 'horizontal', # affected by the ncol=2
legend_text_legend = element_text(size=8),

axis_line = element_line(size=2),
axis_text_x = element_text(size=8),
axis_text_y = element_text(size=8),
axis_title_x = element_text(weight='bold', size=12),
axis_title_y = element_text(weight='bold', size=12),

text = element_text(font = 'arial'),

figure_size = (4.5, 5)
)
# hard-coded x-axis limits (see NOTE(review) above the def)
+ xlim(-60,80)
)
# plt.rcParams['font.sans-serif'] = [
# 'Helvetica', 'Arial', 'Verdana', 'Bitstream Vera Sans'
# ]
# plt.rcParams['font.size'] = 8
# plt.rcParams['font.weight'] = 'regular'
# plt.rcParams['text.color'] = almost_black
#
# axisLineWidth = .5
# plt.rcParams['axes.linewidth'] = axisLineWidth
# plt.rcParams['lines.linewidth'] = 1.5
#
# plt.rcParams['axes.facecolor'] = 'white'
# plt.rcParams['axes.edgecolor'] = almost_black
# plt.rcParams['axes.labelcolor'] = almost_black
# # plt.rcParams['axes.color_cycle'] = dark2_all
#
# plt.rcParams['patch.edgecolor'] = 'none'
# plt.rcParams['patch.linewidth'] = .25
# # plt.rcParams['patch.facecolor'] = dark2_all[0]
#
# plt.rcParams['savefig.dpi'] = 1000
# plt.rcParams['savefig.format'] = 'svg'
#
# plt.rcParams['legend.frameon'] = False
# plt.rcParams['legend.handletextpad'] = .25
# plt.rcParams['legend.fontsize'] = 8
# plt.rcParams['legend.numpoints'] = 1
# plt.rcParams['legend.scatterpoints'] = 1
#
# plt.rcParams['ytick.direction'] = 'out'
# plt.rcParams['ytick.color'] = almost_black
# plt.rcParams['ytick.major.width'] = axisLineWidth
# plt.rcParams['xtick.direction'] = 'out'
# plt.rcParams['xtick.color'] = almost_black
# plt.rcParams['xtick.major.width'] = axisLineWidth


def draw_threshold(x, threshold, pseudo_sd):
    """Evaluate ``threshold * pseudo_sd * sign(x) / abs(x)``.

    ``sign(x)`` is ``1`` where ``x > 0`` and ``-1`` everywhere else
    (including ``x == 0``).

    Args:
        x: A scalar, or an array-like (list, tuple, numpy array) that is
            evaluated element-wise.
        threshold: Multiplicative factor of the numerator.
        pseudo_sd: Multiplicative factor of the numerator.

    Returns:
        For scalar ``x``, the scalar result of the expression above. For
        array-like ``x``, a float numpy array with one value per element;
        elements equal to zero yield ``-inf``/``nan`` instead of raising.

    Raises:
        ZeroDivisionError: If ``x`` is a plain Python number equal to zero.
    """
    if np.ndim(x) == 0:
        # Scalar path: keep the plain-Python arithmetic unchanged.
        sign = 1 if x > 0 else -1
        return threshold * pseudo_sd * sign / abs(x)

    # Array path: a conditional expression cannot branch per element, so the
    # sign is built with np.where instead.
    values = np.asarray(x, dtype=float)
    signs = np.where(values > 0, 1.0, -1.0)
    # Division by zero is expected at x == 0 when sweeping a curve; let it
    # produce inf/nan quietly rather than emit a RuntimeWarning.
    with np.errstate(divide='ignore', invalid='ignore'):
        return threshold * pseudo_sd * signs / np.abs(values)


def prep_data(df_in, threshold):
    """Annotate a copy of ``df_in`` and add a ``-log10(pvalue)`` column.

    Args:
        df_in: Input table; it is copied first, so the object passed in is
            not the one handed to ``ann_score_df``.
        threshold: Forwarded to ``ann_score_df`` as ``threshold=``.

    Returns:
        The object returned by ``ann_score_df`` with an extra
        ``'-log10(pvalue)'`` column computed from its ``pvalue`` column.
    """
    annotated = ann_score_df(df_in.copy(), threshold=threshold)
    annotated['-log10(pvalue)'] = -np.log10(annotated['pvalue'])
    return annotated


def plot_volcano(ax, df_in, threshold, up_hit='resistance_hit', down_hit='sensitivity_hit', xlim_l=-5, xlim_r=5,
                 ylim=6):
    """Draw a volcano plot (score vs. -log10 p-value) on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        df_in: Table passed to ``prep_data``.
        threshold: Passed to ``prep_data``.
        up_hit: Value of the ``label`` column drawn in ``'#fcae91'``.
        down_hit: Value of the ``label`` column drawn in ``'#bdd7e7'``.
        xlim_l: Left x-axis limit.
        xlim_r: Right x-axis limit.
        ylim: Upper y-axis limit (the lower limit is fixed at 0.1).

    Returns:
        None; the plot is drawn on ``ax`` in place.
    """
    df = prep_data(df_in, threshold)

    # (value of the 'label' column, alpha, colour), drawn in this order.
    layers = (
        ('target_non_hit', 0.1, 'black'),
        (up_hit, 0.9, '#fcae91'),
        (down_hit, 0.9, '#bdd7e7'),
        ('non-targeting', 0.1, 'gray'),
    )
    for category, opacity, colour in layers:
        subset = df.loc[df['label'] == category]
        ax.scatter(subset['score'], subset['-log10(pvalue)'],
                   alpha=opacity, s=1, c=colour, label=category)

    # Axis labels and limits.
    ax.set_xlabel('phenotype score')
    ax.set_ylabel('-log10(p-value)')
    ax.set_xlim(xlim_l, xlim_r)
    ax.set_ylim(0.1, ylim)

    ax.legend()


def _label_target(ax, df_in, label, threshold, facecolor, size, size_txt, t_x, t_y):
    """Overlay and annotate the rows whose ``target`` column equals ``label``.

    Shared implementation behind ``label_as_black``,
    ``label_sensitivity_hit`` and ``label_resistance_hit``; they differ only
    in ``facecolor`` and in their default ``t_x``.
    """
    df = prep_data(df_in, threshold)
    target_data = df[df['target'] == label]

    # Scatter plot for labeled data
    ax.scatter(target_data['score'], target_data['-log10(pvalue)'],
               s=size, linewidth=0.5, edgecolors='black', facecolors=facecolor, label='target')

    # Any falsy size_txt (None or 0) falls back to twice the marker size.
    if not size_txt:
        size_txt = size * 2

    # Annotate each point, offset by (t_x, t_y) from its coordinates.
    points = zip(target_data['target'], target_data['score'], target_data['-log10(pvalue)'])
    for txt, score, neg_log_p in points:
        ax.annotate(txt, (score + t_x, neg_log_p + t_y), color='black', size=size_txt)


def label_as_black(ax, df_in, label, threshold, size=2, size_txt=None, t_x=.5, t_y=-0.1):
    """Mark and annotate the points of target ``label`` in black.

    Args:
        ax: Matplotlib axes to draw on.
        df_in: Table passed to ``prep_data``.
        label: Value of the ``target`` column to highlight.
        threshold: Passed to ``prep_data``.
        size: Marker size.
        size_txt: Annotation text size; defaults to ``size * 2`` when falsy.
        t_x: Horizontal offset of the annotation from the point.
        t_y: Vertical offset of the annotation from the point.
    """
    _label_target(ax, df_in, label, threshold, 'black', size, size_txt, t_x, t_y)


def label_sensitivity_hit(ax, df_in, label, threshold, size=2, size_txt=None, t_x=-.5, t_y=-0.1):
    """Mark and annotate the points of target ``label`` with a ``'#3182bd'`` fill.

    Args:
        ax: Matplotlib axes to draw on.
        df_in: Table passed to ``prep_data``.
        label: Value of the ``target`` column to highlight.
        threshold: Passed to ``prep_data``.
        size: Marker size.
        size_txt: Annotation text size; defaults to ``size * 2`` when falsy.
        t_x: Horizontal offset of the annotation (negative by default).
        t_y: Vertical offset of the annotation from the point.
    """
    _label_target(ax, df_in, label, threshold, '#3182bd', size, size_txt, t_x, t_y)


def label_resistance_hit(ax, df_in, label, threshold, size=2, size_txt=None, t_x=.5, t_y=-0.1):
    """Mark and annotate the points of target ``label`` with a ``'#de2d26'`` fill.

    Args:
        ax: Matplotlib axes to draw on.
        df_in: Table passed to ``prep_data``.
        label: Value of the ``target`` column to highlight.
        threshold: Passed to ``prep_data``.
        size: Marker size.
        size_txt: Annotation text size; defaults to ``size * 2`` when falsy.
        t_x: Horizontal offset of the annotation from the point.
        t_y: Vertical offset of the annotation from the point.
    """
    _label_target(ax, df_in, label, threshold, '#de2d26', size, size_txt, t_x, t_y)


# Prepares a two-entry subset of `adata` for a replicate-vs-replicate scatter.
# NOTE(review): only the set-up statements of this function are contiguous
# here; confirm where the plotting calls that consume `bdata`, `x_lab`,
# `y_lab`, `title` and `ax` live.
def plot_replicate_scatter(ax, adata, x, y, title):
# select the entries named `x` and `y` along the first axis and work on a copy
bdata = adata[[x, y], :].copy()

# relabel the obs index as 'Replicate <value>' using the obs `replicate` column
bdata.obs.index = [f'Replicate {str(r)}' for r in bdata.obs.replicate.to_list()]
# the two-name unpacking only succeeds when exactly two entries were selected
x_lab, y_lab = [f'Replicate {str(r)}' for r in bdata.obs.replicate.to_list()]

# NOTE(review): `scatter_p` is never assigned in the statements of this
# function shown above -- verify which function this return belongs to.
return scatter_p


# Builds a plotnine plot of the first two PCA components and stores it in
# `pca_p`. Points are filled by, and text-labelled with, the 'score' column.
# NOTE(review): the x-axis range is fixed to (-60, 80) by xlim() -- confirm
# this suits every data set passed in.
def plot_ggplot_pca(adata):
"""
Plot PCA using ggplot2 style via plotnine.
Args:
adata: anndata object
Returns:
plotnine ggplot object
"""
# Create a dataframe with the PCA coordinates and the metadata
pca = pd.concat([
pd.DataFrame(
# first two columns of the embedding stored under obsm['X_pca']
adata.obsm['X_pca'][:,[0,1]],
index=adata.obs.index,
columns=['PC-1', 'PC-2']
),
# every obs column except 'replicate' (the column must exist to be dropped)
adata.obs.drop('replicate', axis=1)
], axis=1)

pca_p = (
ggplot(pca)
+ geom_point(
aes(
x = 'PC-1',
y = 'PC-2',
fill='score',
# shape = 'treatment'
),
color='black',
size=8
)
+ geom_text(
aes(
x='PC-1',
y='PC-2',
label='score',
# NOTE(review): size=3 is inside aes(), so it is passed as a mapped
# aesthetic rather than as a fixed geom_text parameter -- confirm intended
size=3
),
nudge_y=3,
nudge_x=7,
)
+ theme_classic()
+ theme(
panel_grid_major = element_blank(),
panel_grid_minor = element_blank(),
panel_background = element_blank(),

legend_background = element_blank(),
legend_position = 'top',
legend_direction = 'horizontal', # affected by the ncol=2
legend_text_legend = element_text(size=8),

axis_line = element_line(size=2),
axis_text_x = element_text(size=8),
axis_text_y = element_text(size=8),
axis_title_x = element_text(weight='bold', size=12),
axis_title_y = element_text(weight='bold', size=12),

text = element_text(font = 'arial'),

figure_size = (4.5, 5)
)
+ xlim(-60,80)
# NOTE(review): the statements from here to ax.get_legend().remove() use
# `bdata`, `x_lab`, `y_lab`, `title` and `ax`, none of which are defined in
# plot_ggplot_pca, and `sc` is not among the imports visible at the top of
# this file -- verify which function these lines belong to.
sc.pp.log1p(bdata)
sc.pl.scatter(
bdata,
x_lab, y_lab,
legend_fontsize='xx-large',
palette=[almost_black, '#BFBFBF'],
color='targetType',
title=title,
size=5,
show=False,
ax=ax
)
ax.set_ylim(-1, 11)
ax.set_xlim(-1, 11)
ax.tick_params(axis='both', labelsize=10)
ax.get_legend().remove()

return pca_p
# NOTE(review): this call follows the return above and also uses `ax`, which
# is not a parameter of plot_ggplot_pca -- verify where it belongs.
ax.grid(False)

0 comments on commit bb9b1f1

Please sign in to comment.