Skip to content

Commit

Permalink
docs(layout_columns): Add example app (#903)
Browse files Browse the repository at this point in the history
  • Loading branch information
gadenbuie authored Dec 18, 2023
1 parent b8a2316 commit 8ad1817
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
33 changes: 33 additions & 0 deletions shiny/api-examples/layout_columns/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from model_plots import * # model plots and cards

from shiny import App, Inputs, Outputs, Session, render, ui

app_ui = ui.page_fluid(
ui.panel_title(ui.h2("Model Dashboard")),
ui.markdown("Using `ui.layout_columns()` for the layout."),
ui.layout_columns(
card_loss,
card_acc,
card_feat,
col_widths={"sm": (5, 7, 12)},
# row_heights=(2, 3),
# height="700px",
),
)


def server(input: Inputs, output: Outputs, session: Session):
@render.plot
def loss_over_time():
return plot_loss_over_time()

@render.plot
def accuracy_over_time():
return plot_accuracy_over_time()

@render.plot
def feature_importance():
return plot_feature_importance()


app = App(app_ui, server)
56 changes: 56 additions & 0 deletions shiny/api-examples/layout_columns/model_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import matplotlib.pyplot as plt
import numpy as np

from shiny import ui


def plot_loss_over_time():
epochs = np.arange(1, 101)
loss = 1000 / np.sqrt(epochs) + np.random.rand(100) * 25

fig = plt.figure(figsize=(10, 6))
plt.plot(epochs, loss)
plt.xlabel("Epochs")
plt.ylabel("Loss")
return fig


def plot_accuracy_over_time():
epochs = np.arange(1, 101)
accuracy = np.sqrt(epochs) / 12 + np.random.rand(100) * 0.15
accuracy = [np.min([np.max(accuracy[:i]), 1]) for i in range(1, 101)]

fig = plt.figure(figsize=(10, 6))
plt.plot(epochs, accuracy)
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
return fig


def plot_feature_importance():
features = ["Product Category", "Price", "Brand", "Rating", "Number of Reviews"]
importance = np.random.rand(5)

fig = plt.figure(figsize=(10, 6))
plt.barh(features, importance)
plt.xlabel("Importance")
return fig


card_loss = ui.card(
ui.card_header("Loss Over Time"),
ui.output_plot("loss_over_time"),
full_screen=True,
)

card_acc = ui.card(
ui.card_header("Accuracy Over Time"),
ui.output_plot("accuracy_over_time"),
full_screen=True,
)

card_feat = ui.card(
ui.card_header("Feature Importance"),
ui.output_plot("feature_importance"),
full_screen=True,
)

0 comments on commit 8ad1817

Please sign in to comment.