Skip to content

Commit

Permalink
Improved Shiny app
Browse files Browse the repository at this point in the history
  • Loading branch information
itrujnara committed Jun 21, 2024
1 parent 714aa45 commit 47000f0
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 65 deletions.
269 changes: 219 additions & 50 deletions bin/shiny_app/shiny_app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
# This app is translated from Mastering Shinywidgets
# https://mastering-shiny.org/basic-reactivity.html#reactive-expressions-1
from shiny import App, render, ui
from shinywidgets import output_widget, render_widget
from numpy import random
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path
import sys
import os
import shiny_app_merge_score_and_trace as ms
from mythemes import theme_light, theme_dark, theme_contrast

# Style
sns.set(context="talk", style="white", font_scale=0.8)
sns.set_theme(context="talk", style="dark", font_scale=0.8)


# Load file
Expand All @@ -32,97 +38,260 @@
print("ERROR: file not found: ", summary_report)
sys.exit(1)

def merge_tree_args(row):
if str(row["tree"]) == "nan":
return "None"
elif str(row["args_tree"]) == "nan":
return str(row["tree"]) + " ()"
else:
return str(row["tree"]) + " (" + str(row["args_tree"]) + ")"

inputfile["tree_args"] = inputfile.apply(merge_tree_args, axis=1)

def merge_aligner_args(row):
if str(row["aligner"]) == "nan":
return "None"
elif str(row["args_aligner"]) == "nan":
return str(row["aligner"]) + " ()"
else:
return str(row["aligner"]) + " (" + str(row["args_aligner"]) + ")"

inputfile["aligner_args"] = inputfile.apply(merge_aligner_args, axis=1)


# ----------------------------------------------------------------------------

options = {item: item for item in list(inputfile.columns)}
options_color = {"aligner": "assembly", "tree": "tree"}

options_color = {"aligner": "Assembly",
"aligner_args": "Assembly with args",
"tree": "Tree",
"tree_args": "Tree with args"}

options_eval = {
"sp": "sum of pairs (SP)",
"n_sequences": "# sequences",
"tc": "total column score (TC)",
"perc_sim": "sequences avg similarity",
"seq_length_mean": "sequence length (mean)",
"time_tree": "tree time (min)",
"time_align": "alignment time (min)",
"memory_tree": "tree memory (GB)",
"memory_align": "alignment memory (GB)",
"avg_plddt": "avg pLDDT"
"sp": "Sum of Pairs (SP)",
"n_sequences": "Number of Sequences",
"tc": "Total Column Score (TC)",
"perc_sim": "Average Sequence Similarity (%)",
"seq_length_mean": "Mean Sequence Length",
"time_tree": "Tree Building Time (min)",
"time_align": "Alignment Time (min)",
"memory_tree": "Tree Building Memory (GB)",
"memory_align": "Alignment Memory (GB)",
"avg_plddt": "Average pLDDT",
"aligner": "Assembly",
"aligner_args": "Assembly with args",
"tree": "Tree",
"tree_args": "Tree with args"
}

vars_cat = ["aligner", "tree", "tree_args", "aligner_args"]

options_theme = {
"plotly": "Default",
"plotly_white": "Light",
"plotly_dark": "Dark"
}

palettes = {
"theme_light": "pastel",
"theme_dark": "deep",
"theme_contrast": "bright"
}

xlims = {
"sp": [0, 100],
"tc": [0, 100],
"perc_sim": [0, 100],
"tcs": [0, 1000],
"plddt": [0, 100]
}

app_ui = ui.page_fluid(
# HEAD
# Links
ui.tags.link(
rel="stylesheet",
href="bootstrap.min.css"
),
ui.tags.link(
rel="stylesheet",
href="style.css"
),
############################
# BODY
# Header
ui.column(
5,
{"class": "col-md-10 col-lg-8 py-5 mx-auto text-lg-center text-left"},
# Title
ui.h1("Explore the benchmarking results"),
ui.h1({"class": "fw-bold"},
ui.span({"class": "text-primary"}, "nf-core/"), "multiplesequencealign"),
# Subtitle
ui.h2({"class": "text-muted"}, "Stats & Evaluation Explorer"),
# input slider
),
ui.row(
{"class": "col-md-10 col-lg-8 py-5 mx-auto text-lg-center text-left"},
ui.column(
4,
# Main body
ui.layout_sidebar(
# Sidebar
ui.sidebar(
# Mappings heading
ui.h3("Mappings"),
# X axis input
ui.input_select(
"x",
"X axis: ",
{
"x axis": options_eval,
"X axis": options_eval,
},
selected="n_sequences",
selected="n_sequences",
),
),
ui.column(
4,
# Y axis input
ui.input_select(
"y",
"Y axis: ",
{
"y axis": options_eval,
"Y axis": options_eval,
},
selected="sp",
),
),
ui.column(
4,
# Color input
ui.input_select(
"color",
"color: ",
"Color: ",
{
"color": options_color,
"Color": options_color,
},
selected="align",
),
# Linear model checkbox
ui.input_checkbox("lm", "Show linear model (scatterplot)", value=False),
# Style heading
ui.h3("Style"),
# General
ui.h4("General"),
ui.input_select(
"theme",
"Theme: ",
{
"Theme": options_theme,
},
selected="Dark"
),
# Scatter plot
ui.h4("Scatter plot"),
# Point size input
ui.input_numeric("size", "Point size: ", min=1, max=100, step=10, value=60)
),
ui.column(
4,
ui.input_numeric("size", "dot's size: ", min=1, max=100, step=10, value=60),
),
),
ui.row(
ui.column(
4, {"class": "col-md-40 col-lg-25 py-10 mx-auto text-lg-center text-left"}, ui.output_plot("scatter")
),
),
# Plots
ui.navset_tab(
ui.nav_panel(
"Scatter plot",
ui.column(
5,
{"class": "col-md-40 col-lg-25 py-10 mx-auto text-lg-center text-left"},
output_widget("autoplot", width = "clamp(400px, 50vw, 800px)", height = "clamp(300px, 40vh, 600px)")
)
),
ui.nav_panel(
"Correlation",
ui.column(
5,
{"class": "col-md-40 col-lg-25 py-10 mx-auto text-lg-center text-left"},
output_widget("corr", width = "clamp(400px, 50vw, 800px)", height = "clamp(400px, 50vh, 800px)")
)
)
)
)
)


def server(input, output, session):
@output
@render.plot
def scatter():
plt.ylim(0, 100)
plt.xlim(0, 100)
@render_widget
def autoplot():
if input.x() in vars_cat and input.y() in vars_cat: # heatmap
return heatmap()
elif input.x() in vars_cat: # vertical boxplot
return boxplot_vertical()
elif input.y() in vars_cat: # horizontal boxplot
return boxplot_horizontal()
else: # scatterplot
return scatterplot()

def heatmap():
x = input.x()
y = input.y()
xtab = pd.crosstab(inputfile[x], inputfile[y])
fig = px.imshow(xtab, x=xtab.columns, y=xtab.index, text_auto=True)
fig.update_layout(
template = input.theme(),
xaxis_title=options_eval.get(y, y),
yaxis_title=options_eval.get(x, x)
)
return fig

def boxplot_horizontal():
x = input.x()
y = input.y()
fig = px.box(inputfile.fillna(''), x=x, y=y, color=y)

fig.update_layout(
template = input.theme(),
xaxis_title=options_eval.get(x, x),
yaxis_title=options_eval.get(y, y),
legend_title_text=options_eval.get(y, y)
)

x_label = options_eval[input.x()]
y_label = options_eval[input.y()]
return fig

ax = sns.scatterplot(data=inputfile, x=input.x(), y=input.y(), hue=input.color(), s=input.size())
def boxplot_vertical():
x = input.x()
y = input.y()

ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
fig = px.box(inputfile.fillna(""), x=x, y=y, color=x)

fig.update_layout(
template = input.theme(),
xaxis_title=options_eval.get(x, x),
yaxis_title=options_eval.get(y, y),
legend_title_text=options_eval.get(x, x)
)

return fig

def scatterplot():
x = input.x()
y = input.y()
color = input.color()
size = input.size()

fig = px.scatter(inputfile, x=x, y=y, color=color, trendline="ols" if input.lm() else None, trendline_scope="overall")

fig.update_traces(marker=dict(size=size/5))

fig.update_layout(
template = input.theme(),
xaxis_title=options_eval.get(x, x),
yaxis_title=options_eval.get(y, y),
xaxis = dict(range = xlims.get(x, [0, None])),
yaxis = dict(range = xlims.get(y, [0, None]))
)

return fig


@output
@render_widget
def corr():
data = inputfile[list(set(options_eval.keys()) & set(inputfile.columns) - set(vars_cat))]
corr = data.corr()
xlabs = [options_eval.get(x, x) for x in corr.columns]
ylabs = [options_eval.get(y, y) for y in corr.index]

plt.legend(bbox_to_anchor=(1.05, 1), loc=3, borderaxespad=0.0)
return ax
fig = px.imshow(corr, x=xlabs, y=ylabs, text_auto=".2f", labels = options_eval)
return fig


app = App(app_ui, server)
app_dir = Path(__file__).parent
app = App(app_ui, server, static_assets = app_dir / "static")
35 changes: 20 additions & 15 deletions bin/shiny_app/shiny_app_merge_score_and_trace.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import pandas as pd
import re

def convert_time(time):
if time is not None:
if "ms" in time:
time = time.replace('ms', '')
time = float(time)/60000
elif "s" in time:
time = time.replace('s', '')
time = float(time)/60
elif "m" in time:
time = time.replace('m', '')
elif "h" in time:
time = time.replace('h', '')
time = float(time)*60
return time
def convert_time(time_str):
# Regular expression to match the time components
pattern = re.compile(r'((?P<hours>\d+)h)?\s*((?P<minutes>\d+)m)?\s*((?P<seconds>\d+)s)?\s*((?P<milliseconds>\d+)ms)?')
match = pattern.fullmatch(time_str.strip())

if not match:
raise ValueError("Time string is not in the correct format")

time_components = match.groupdict(default='0')

hours = int(time_components['hours'])
minutes = int(time_components['minutes'])
seconds = int(time_components['seconds'])
milliseconds = int(time_components['milliseconds'])

# Convert everything to minutes
total_minutes = (hours * 60) + minutes + (seconds / 60) + (milliseconds / 60000)

return total_minutes

def convert_memory(memory):
# from anything to GB
Expand Down Expand Up @@ -96,4 +102,3 @@ def merge_data_and_trace(data_file,trace_file,out_file_name):

# write to file
data_tree_align.to_csv(out_file_name, index=False)

1 change: 1 addition & 0 deletions bin/shiny_app/static/bootstrap.min.css

Large diffs are not rendered by default.

Loading

0 comments on commit 47000f0

Please sign in to comment.