diff --git a/VizAble/app.py b/VizAble/app.py index 9de5b6b..d2f4b66 100644 --- a/VizAble/app.py +++ b/VizAble/app.py @@ -10,6 +10,8 @@ from shinywidgets import output_widget, render_widget import google.generativeai as genai import os, configparser +import matplotlib.pyplot as plt +import seaborn as sns app_ui = ui.page_navbar( # theme for the app, @@ -46,6 +48,7 @@ def server(input: Inputs, output: Outputs, session: Session): file_check = reactive.value(False) reactive_df: reactive.Value[pd.DataFrame] = reactive.Value(pd.DataFrame()) reactive_dtypes_df: reactive.Value[pd.DataFrame] = reactive.Value(pd.DataFrame()) + decoded_image: str = reactive.value(None) # Step 1: Upload a File @render.ui @@ -585,7 +588,7 @@ def get_output_selected_cols() -> pd.DataFrame: return selected_cols # Step 5: Generate Plots - @render_widget + @render.plot @reactive.event(input.generate) def get_output_plot(): print("The generate plot button is clicked.") @@ -607,146 +610,152 @@ def get_output_plot(): markers = input.markers() # color_by = input.color_by() - line_plot = px.line( - data_frame = data_frame, + # line_plot = px.line( + # data_frame = data_frame, + # x = input.line_x_axis(), + # y = input.line_y_axis(), + # markers = markers, + # # color = color_by, + # ).update_layout( + # template="seaborn", + # title={"text": plot_title, "x": 0.5}, + # ).update_xaxes( + # title_text = x_axis_title, + # ).update_yaxes( + # title_text = y_axis_title, + # ) + sns.set_theme() + plt.figure(figsize=(10, 6)) + line_plot = sns.lineplot( + data = data_frame, x = input.line_x_axis(), y = input.line_y_axis(), - markers = markers, - # color = color_by, - ).update_layout( - template="seaborn", - title={"text": plot_title, "x": 0.5}, - ).update_xaxes( - title_text = x_axis_title, - ).update_yaxes( - title_text = y_axis_title, - ) - - return line_plot + markers = input.markers(), + ) + plt.title(plot_title) + plt.xlabel(x_axis_title) + plt.ylabel(y_axis_title) + # plt.show() + return_plot = line_plot - # Bar Plot: - if input.plot_types() == "Bar Plot": - req(input.bar_x_axis()) - plot_title = input.bar_plot_title() - x_axis = input.bar_x_axis() - x_axis_title=input.bar_x_axis_title() - y_axis_title=input.bar_y_axis_title() - - # generate dataframe for value counts - counts_df = data_frame[x_axis].value_counts().reset_index() - - # rename columns name - counts_df.columns = ['value', 'count'] - - bar_plot = px.bar( - data_frame = counts_df, - x = "value", - y = "count", - color="value", - ).update_layout( - template = "seaborn", - title = {"text": plot_title, "x": 0.5}, - ).update_xaxes( - title_text = x_axis_title, - ).update_yaxes( - title_text = y_axis_title, - ) - - return bar_plot + # # Bar Plot: + # if input.plot_types() == "Bar Plot": + # req(input.bar_x_axis()) + # plot_title = input.bar_plot_title() + # x_axis = input.bar_x_axis() + # x_axis_title=input.bar_x_axis_title() + # y_axis_title=input.bar_y_axis_title() + + # # generate dataframe for value counts + # counts_df = data_frame[x_axis].value_counts().reset_index() + + # # rename columns name + # counts_df.columns = ['value', 'count'] + + # bar_plot = px.bar( + # data_frame = counts_df, + # x = "value", + # y = "count", + # color="value", + # ).update_layout( + # template = "seaborn", + # title = {"text": plot_title, "x": 0.5}, + # ).update_xaxes( + # title_text = x_axis_title, + # ).update_yaxes( + # title_text = y_axis_title, + # ) + # return_plot = bar_plot - # Box Plot: - if input.plot_types() == "Box Plot": - req(input.box_y_axis()) - # y_axis = input.box_y_axis() - plot_title = input.box_plot_title() - y_axis_title = input.box_y_axis_title() - - box_plot = px.box( - data_frame = data_frame, - y = input.box_y_axis(), - ).update_layout( - template="seaborn", - title={"text": plot_title, "x": 0.5}, - ).update_yaxes( - title_text = y_axis_title, - ) - - return box_plot + # # Box Plot: + # if input.plot_types() == "Box Plot": + # req(input.box_y_axis()) + # # y_axis = input.box_y_axis() + # plot_title = input.box_plot_title() + # y_axis_title = input.box_y_axis_title() + + # box_plot = px.box( + # data_frame = data_frame, + # y = input.box_y_axis(), + # ).update_layout( + # template="seaborn", + # title={"text": plot_title, "x": 0.5}, + # ).update_yaxes( + # title_text = y_axis_title, + # ) + # return_plot = box_plot - # Grouped_Box Plot: - if input.plot_types() == "Grouped_Box Plot": - req(input.grouped_box_y_axis(), input.grouped_box_grouping()) - plot_title = input.grouped_box_plot_title() - y_axis_title = input.grouped_box_y_axis_title() - grouping = input.grouped_box_grouping() - - box_plot_grouped = px.box( - data_frame = data_frame, - x = grouping, - y = input.grouped_box_y_axis(), - color = grouping, - ).update_layout( - template="seaborn", - title={"text": plot_title, "x": 0.5}, - ).update_yaxes( - title_text = y_axis_title, - ) - - return box_plot_grouped - - # Histogram: - if input.plot_types() == "Histogram": - req(input.histogram_x_axis()) - plot_title = input.histogram_plot_title() - - # x_axis title - x_axis_title = input.histogram_x_axis_title() - y_axis_title = input.histogram_y_axis_title() - - histogram = px.histogram( - data_frame = data_frame, - x = input.histogram_x_axis(), - nbins = input.histogram_bin_size(), - ).update_layout( - template="seaborn", - title={"text": plot_title, "x": 0.5}, - ).update_xaxes( - title_text = x_axis_title, - ).update_yaxes( - title_text = y_axis_title, - ) - - return histogram - - # Scatter Plot: - if input.plot_types() == "Scatter Plot": - req(input.scatter_x_axis(), input.scatter_y_axis()) - plot_title = input.scatter_plot_title() - - # x_axis + yaxis title - x_axis_title = input.scatter_x_axis_title() - y_axis_title = input.scatter_y_axis_title() - - scatter_plot = px.scatter( - data_frame = data_frame, - x = input.scatter_x_axis(), - y = input.scatter_y_axis(), - ).update_layout( - template="seaborn", - title={"text": plot_title, "x": 0.5}, - ).update_xaxes( - title_text = x_axis_title, - ).update_yaxes( - title_text = y_axis_title, - ) - - # ui.update_select( - # id = "scatter_color_by", - # label = "Group by(color)", - # choices = color_by_choices - # ) - return scatter_plot - + # # Grouped_Box Plot: + # if input.plot_types() == "Grouped_Box Plot": + # req(input.grouped_box_y_axis(), input.grouped_box_grouping()) + # plot_title = input.grouped_box_plot_title() + # y_axis_title = input.grouped_box_y_axis_title() + # grouping = input.grouped_box_grouping() + + # box_plot_grouped = px.box( + # data_frame = data_frame, + # x = grouping, + # y = input.grouped_box_y_axis(), + # color = grouping, + # ).update_layout( + # template="seaborn", + # title={"text": plot_title, "x": 0.5}, + # ).update_yaxes( + # title_text = y_axis_title, + # ) + # return_plot = box_plot_grouped + + # # Histogram: + # if input.plot_types() == "Histogram": + # req(input.histogram_x_axis()) + # plot_title = input.histogram_plot_title() + + # # x_axis title + # x_axis_title = input.histogram_x_axis_title() + # y_axis_title = input.histogram_y_axis_title() + + # histogram = px.histogram( + # data_frame = data_frame, + # x = input.histogram_x_axis(), + # nbins = input.histogram_bin_size(), + # ).update_layout( + # template="seaborn", + # title={"text": plot_title, "x": 0.5}, + # ).update_xaxes( + # title_text = x_axis_title, + # ).update_yaxes( + # title_text = y_axis_title, + # ) + # return_plot = histogram + + # # Scatter Plot: + # if input.plot_types() == "Scatter Plot": + # req(input.scatter_x_axis(), input.scatter_y_axis()) + # plot_title = input.scatter_plot_title() + + # # x_axis + yaxis title + # x_axis_title = input.scatter_x_axis_title() + # y_axis_title = input.scatter_y_axis_title() + + # scatter_plot = px.scatter( + # data_frame = data_frame, + # x = input.scatter_x_axis(), + # y = input.scatter_y_axis(), + # ).update_layout( + # template="seaborn", + # title={"text": plot_title, "x": 0.5}, + # ).update_xaxes( + # title_text = x_axis_title, + # ).update_yaxes( + # title_text = y_axis_title, + # ) + + # return_plot = scatter_plot + + decoded_image.set(functions.decode_plot(return_plot)) + print(decoded_image.get()) + return return_plot + # Step 6: Chatbot @render.text @reactive.event(input.ask) @@ -768,13 +777,26 @@ def get_chatbot_output(): ) # initialize the model - model = genai.GenerativeModel("gemini-pro") + model = genai.GenerativeModel("gemini-pro-vision") chat = model.start_chat(history = []) question = input.chatbot_input() + # Send the encoded image to Gemini for analysis + prompt="Could you please provide a summary of this plot?" + # generate content - response = chat.send_message(question) - return response.text + if question is None: + response = model.generate_content([decoded_image.get(), prompt]) + else: + response = model.generate_content([decoded_image.get(), question]) + + # Construct the HTML string for the ARIA live region + # Use aria-live="assertive" to announce updates immediately + live_region_html = ui.tags.div(response.text, style="position: absolute; left: -9999px;", aria_live="assertive") + + # Return the HTML string containing the ARIA live region + # return ui.HTML(live_region_html) + return live_region_html app = App(app_ui, server) diff --git a/VizAble/chatbot.py b/VizAble/chatbot.py index c31ae75..899ad38 100644 --- a/VizAble/chatbot.py +++ b/VizAble/chatbot.py @@ -30,7 +30,7 @@ def chatbot_ui() -> ui.nav_panel: width="100%", ), ), - ui.output_text_verbatim("get_chatbot_output"), + ui.output_text("get_chatbot_output"), ), height="80vh", ), diff --git a/VizAble/functions.py b/VizAble/functions.py index f85a662..2be67ce 100644 --- a/VizAble/functions.py +++ b/VizAble/functions.py @@ -4,6 +4,10 @@ import openpyxl import pandas as pd from pandas.errors import ParserError +import plotly.io as pio +import base64, io +import matplotlib.pyplot as plt +from PIL import Image def sep_input_radio_buttons() -> ui.input_radio_buttons: """ Create a radio button group for users to select a separator for input. @@ -248,4 +252,18 @@ def update_grouping_input_select(plot_type: str, choices: List[str]) -> ui.updat id=grouping_id, choices=choices, selected=None - ) \ No newline at end of file + ) + +def decode_plot(return_plot): + # Get the current Matplotlib figure associated with the plot + fig = return_plot.get_figure() + + # Save the plot as a PNG image in memory + buffer = io.BytesIO() + fig.savefig(buffer, format='png') + buffer.seek(0) + + # Decode the image + decoded_image = Image.open(buffer) + + return decoded_image \ No newline at end of file diff --git a/VizAble/generate_plots.py b/VizAble/generate_plots.py index 7056bbb..886acc0 100644 --- a/VizAble/generate_plots.py +++ b/VizAble/generate_plots.py @@ -148,7 +148,8 @@ def generate_plots_ui() -> ui.nav_panel: "You can generate a plot by clicking this \"Generate Plot\" button, but if you have modify anything related to this plot, you need to click this button again to update the plot.", ), ), - output_widget("get_output_plot"), + + ui.output_plot("get_output_plot"), height="80vh", ), ),