Skip to content

Commit

Permalink
Merge pull request #190 from Rebekah-Chuang:Rebekah-Chuang/issue189
Browse files Browse the repository at this point in the history
feat: encode plot with base64 and implement ARIA live
  • Loading branch information
Rebekah-Chuang authored Apr 19, 2024
2 parents 584267a + 310f890 commit c30fbea
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 142 deletions.
300 changes: 161 additions & 139 deletions VizAble/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from shinywidgets import output_widget, render_widget
import google.generativeai as genai
import os, configparser
import matplotlib.pyplot as plt
import seaborn as sns

app_ui = ui.page_navbar(
# theme for the app,
Expand Down Expand Up @@ -46,6 +48,7 @@ def server(input: Inputs, output: Outputs, session: Session):
file_check = reactive.value(False)
reactive_df: reactive.Value[pd.DataFrame] = reactive.Value(pd.DataFrame())
reactive_dtypes_df: reactive.Value[pd.DataFrame] = reactive.Value(pd.DataFrame())
decoded_image: str = reactive.value(None)

# Step 1: Upload a File
@render.ui
Expand Down Expand Up @@ -585,7 +588,7 @@ def get_output_selected_cols() -> pd.DataFrame:
return selected_cols

# Step 5: Generate Plots
@render_widget
@render.plot
@reactive.event(input.generate)
def get_output_plot():
print("The generate plot button is clicked.")
Expand All @@ -607,146 +610,152 @@ def get_output_plot():
markers = input.markers()
# color_by = input.color_by()

line_plot = px.line(
data_frame = data_frame,
# line_plot = px.line(
# data_frame = data_frame,
# x = input.line_x_axis(),
# y = input.line_y_axis(),
# markers = markers,
# # color = color_by,
# ).update_layout(
# template="seaborn",
# title={"text": plot_title, "x": 0.5},
# ).update_xaxes(
# title_text = x_axis_title,
# ).update_yaxes(
# title_text = y_axis_title,
# )
sns.set_theme()
plt.figure(figsize=(10, 6))
line_plot = sns.lineplot(
data = data_frame,
x = input.line_x_axis(),
y = input.line_y_axis(),
markers = markers,
# color = color_by,
).update_layout(
template="seaborn",
title={"text": plot_title, "x": 0.5},
).update_xaxes(
title_text = x_axis_title,
).update_yaxes(
title_text = y_axis_title,
)

return line_plot
markers = input.markers(),
)
plt.title(plot_title)
plt.xlabel(x_axis_title)
plt.ylabel(y_axis_title)
# plt.show()
return_plot = line_plot

# Bar Plot:
if input.plot_types() == "Bar Plot":
req(input.bar_x_axis())
plot_title = input.bar_plot_title()
x_axis = input.bar_x_axis()
x_axis_title=input.bar_x_axis_title()
y_axis_title=input.bar_y_axis_title()

# generate dataframe for value counts
counts_df = data_frame[x_axis].value_counts().reset_index()

# rename columns name
counts_df.columns = ['value', 'count']

bar_plot = px.bar(
data_frame = counts_df,
x = "value",
y = "count",
color="value",
).update_layout(
template = "seaborn",
title = {"text": plot_title, "x": 0.5},
).update_xaxes(
title_text = x_axis_title,
).update_yaxes(
title_text = y_axis_title,
)

return bar_plot
# # Bar Plot:
# if input.plot_types() == "Bar Plot":
# req(input.bar_x_axis())
# plot_title = input.bar_plot_title()
# x_axis = input.bar_x_axis()
# x_axis_title=input.bar_x_axis_title()
# y_axis_title=input.bar_y_axis_title()

# # generate dataframe for value counts
# counts_df = data_frame[x_axis].value_counts().reset_index()

# # rename columns name
# counts_df.columns = ['value', 'count']

# bar_plot = px.bar(
# data_frame = counts_df,
# x = "value",
# y = "count",
# color="value",
# ).update_layout(
# template = "seaborn",
# title = {"text": plot_title, "x": 0.5},
# ).update_xaxes(
# title_text = x_axis_title,
# ).update_yaxes(
# title_text = y_axis_title,
# )
# return_plot = bar_plot

# Box Plot:
if input.plot_types() == "Box Plot":
req(input.box_y_axis())
# y_axis = input.box_y_axis()
plot_title = input.box_plot_title()
y_axis_title = input.box_y_axis_title()

box_plot = px.box(
data_frame = data_frame,
y = input.box_y_axis(),
).update_layout(
template="seaborn",
title={"text": plot_title, "x": 0.5},
).update_yaxes(
title_text = y_axis_title,
)

return box_plot
# # Box Plot:
# if input.plot_types() == "Box Plot":
# req(input.box_y_axis())
# # y_axis = input.box_y_axis()
# plot_title = input.box_plot_title()
# y_axis_title = input.box_y_axis_title()

# box_plot = px.box(
# data_frame = data_frame,
# y = input.box_y_axis(),
# ).update_layout(
# template="seaborn",
# title={"text": plot_title, "x": 0.5},
# ).update_yaxes(
# title_text = y_axis_title,
# )
# return_plot = box_plot

# Grouped_Box Plot:
if input.plot_types() == "Grouped_Box Plot":
req(input.grouped_box_y_axis(), input.grouped_box_grouping())
plot_title = input.grouped_box_plot_title()
y_axis_title = input.grouped_box_y_axis_title()
grouping = input.grouped_box_grouping()

box_plot_grouped = px.box(
data_frame = data_frame,
x = grouping,
y = input.grouped_box_y_axis(),
color = grouping,
).update_layout(
template="seaborn",
title={"text": plot_title, "x": 0.5},
).update_yaxes(
title_text = y_axis_title,
)

return box_plot_grouped

# Histogram:
if input.plot_types() == "Histogram":
req(input.histogram_x_axis())
plot_title = input.histogram_plot_title()

# x_axis title
x_axis_title = input.histogram_x_axis_title()
y_axis_title = input.histogram_y_axis_title()

histogram = px.histogram(
data_frame = data_frame,
x = input.histogram_x_axis(),
nbins = input.histogram_bin_size(),
).update_layout(
template="seaborn",
title={"text": plot_title, "x": 0.5},
).update_xaxes(
title_text = x_axis_title,
).update_yaxes(
title_text = y_axis_title,
)

return histogram

# Scatter Plot:
if input.plot_types() == "Scatter Plot":
req(input.scatter_x_axis(), input.scatter_y_axis())
plot_title = input.scatter_plot_title()

# x_axis + yaxis title
x_axis_title = input.scatter_x_axis_title()
y_axis_title = input.scatter_y_axis_title()

scatter_plot = px.scatter(
data_frame = data_frame,
x = input.scatter_x_axis(),
y = input.scatter_y_axis(),
).update_layout(
template="seaborn",
title={"text": plot_title, "x": 0.5},
).update_xaxes(
title_text = x_axis_title,
).update_yaxes(
title_text = y_axis_title,
)

# ui.update_select(
# id = "scatter_color_by",
# label = "Group by(color)",
# choices = color_by_choices
# )
return scatter_plot

# # Grouped_Box Plot:
# if input.plot_types() == "Grouped_Box Plot":
# req(input.grouped_box_y_axis(), input.grouped_box_grouping())
# plot_title = input.grouped_box_plot_title()
# y_axis_title = input.grouped_box_y_axis_title()
# grouping = input.grouped_box_grouping()

# box_plot_grouped = px.box(
# data_frame = data_frame,
# x = grouping,
# y = input.grouped_box_y_axis(),
# color = grouping,
# ).update_layout(
# template="seaborn",
# title={"text": plot_title, "x": 0.5},
# ).update_yaxes(
# title_text = y_axis_title,
# )
# return_plot = box_plot_grouped

# # Histogram:
# if input.plot_types() == "Histogram":
# req(input.histogram_x_axis())
# plot_title = input.histogram_plot_title()

# # x_axis title
# x_axis_title = input.histogram_x_axis_title()
# y_axis_title = input.histogram_y_axis_title()

# histogram = px.histogram(
# data_frame = data_frame,
# x = input.histogram_x_axis(),
# nbins = input.histogram_bin_size(),
# ).update_layout(
# template="seaborn",
# title={"text": plot_title, "x": 0.5},
# ).update_xaxes(
# title_text = x_axis_title,
# ).update_yaxes(
# title_text = y_axis_title,
# )
# return_plot = histogram

# # Scatter Plot:
# if input.plot_types() == "Scatter Plot":
# req(input.scatter_x_axis(), input.scatter_y_axis())
# plot_title = input.scatter_plot_title()

# # x_axis + yaxis title
# x_axis_title = input.scatter_x_axis_title()
# y_axis_title = input.scatter_y_axis_title()

# scatter_plot = px.scatter(
# data_frame = data_frame,
# x = input.scatter_x_axis(),
# y = input.scatter_y_axis(),
# ).update_layout(
# template="seaborn",
# title={"text": plot_title, "x": 0.5},
# ).update_xaxes(
# title_text = x_axis_title,
# ).update_yaxes(
# title_text = y_axis_title,
# )

# return_plot = scatter_plot

decoded_image.set(functions.decode_plot(return_plot))
print(decoded_image.get())
return return_plot

# Step 6: Chatbot
@render.text
@reactive.event(input.ask)
Expand All @@ -768,13 +777,26 @@ def get_chatbot_output():
)

# initialize the model
model = genai.GenerativeModel("gemini-pro")
model = genai.GenerativeModel("gemini-pro-vision")
chat = model.start_chat(history = [])
question = input.chatbot_input()

# Send the encoded image to Gemini for analysis
prompt="Could you please provide a summary of this plot?"

# generate content
response = chat.send_message(question)
return response.text
if question is None:
response = model.generate_content([decoded_image.get(), prompt])
else:
response = model.generate_content([decoded_image.get(), question])

# Construct the HTML string for the ARIA live region
# Use aria-live="assertive" to announce updates immediately
live_region_html = ui.tags.div(response.text, style="position: absolute; left: -9999px;", aria_live="assertive")

# Return the HTML string containing the ARIA live region
# return ui.HTML(live_region_html)
return live_region_html

app = App(app_ui, server)

Expand Down
2 changes: 1 addition & 1 deletion VizAble/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def chatbot_ui() -> ui.nav_panel:
width="100%",
),
),
ui.output_text_verbatim("get_chatbot_output"),
ui.output_text("get_chatbot_output"),
),
height="80vh",
),
Expand Down
20 changes: 19 additions & 1 deletion VizAble/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import openpyxl
import pandas as pd
from pandas.errors import ParserError
import plotly.io as pio
import base64, io
import matplotlib.pyplot as plt
from PIL import Image

def sep_input_radio_buttons() -> ui.input_radio_buttons:
""" Create a radio button group for users to select a separator for input.
Expand Down Expand Up @@ -248,4 +252,18 @@ def update_grouping_input_select(plot_type: str, choices: List[str]) -> ui.updat
id=grouping_id,
choices=choices,
selected=None
)
)

def decode_plot(return_plot):
# Get the current Matplotlib figure associated with the plot
fig = return_plot.get_figure()

# Save the plot as a PNG image in memory
buffer = io.BytesIO()
fig.savefig(buffer, format='png')
buffer.seek(0)

# Decode the image
decoded_image = Image.open(buffer)

return decoded_image
3 changes: 2 additions & 1 deletion VizAble/generate_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def generate_plots_ui() -> ui.nav_panel:
"You can generate a plot by clicking this \"Generate Plot\" button, but if you have modify anything related to this plot, you need to click this button again to update the plot.",
),
),
output_widget("get_output_plot"),

ui.output_plot("get_output_plot"),
height="80vh",
),
),
Expand Down

0 comments on commit c30fbea

Please sign in to comment.