diff --git a/bibmon/_bibmon_tools.py b/bibmon/_bibmon_tools.py index fd598d9..ceca675 100644 --- a/bibmon/_bibmon_tools.py +++ b/bibmon/_bibmon_tools.py @@ -692,4 +692,76 @@ def comparative_table (models, X_train, X_validation, X_test, return_tables.append(times_df) - return return_tables \ No newline at end of file + return return_tables + +############################################################################## + +def filter_and_plot_data(data, color, title, not_wanted_tags = [], remove_NaN_columns = True, remove_zero_columns = True): + """ + Filters and plots data from a dictionary, concatenating the values into a single DataFrame. + Optionally removes columns that contain only NaN values or only zeros. The remaining data is + plotted with each column displayed in individual vertically stacked subplots, sharing the same X-axis. + + Parameters + ---------- + data: dict + A dictionary containing the data. The keys represent categories or filenames, and + the values are pandas DataFrames. + color: string + The color to be used for the plot lines. + title: string + The main title for the plot. + not_wanted_tags: list, optional + A list of column names (tags) to be excluded from the plots. Default is an empty list. + remove_NaN_columns: bool, optional + If True, columns with only NaN values will be removed from the data before plotting. Default is True. + remove_zero_columns: bool, optional + If True, columns that contain only zeros will be removed from the data before plotting. Default is True. + + Returns + ---------- + filteredByData: pandas.DataFrame + The processed and filtered data that was used in the visualization. + tags: list + A list of column names (tags) that were plotted. + """ + from matplotlib.ticker import MaxNLocator + + filtered_by_data = data.apply(pd.to_numeric, errors='coerce') + + if remove_NaN_columns: + filtered_by_data = filtered_by_data.dropna(axis=1, how='all') + + if remove_zero_columns: + filtered_by_data = filtered_by_data.loc[:, (filtered_by_data != 0).any(axis=0)] + + tags = list(filtered_by_data.keys()) + tags = [key for key in tags if key not in not_wanted_tags] + + fig, ax = plt.subplots(len(tags), 1, figsize=(18, 10), sharex=True) + fig.suptitle(f"{title}", fontsize=16) + + timestamp_label = filtered_by_data.index + + for i, tag in enumerate(tags): + tagData = filtered_by_data[tag].values + ax[i].plot(tagData, c=color, linewidth=0.8) + ax[i].set_ylabel(tag, rotation=0, fontsize=14,labelpad=100) + ax[i].set_yticks([]) + + ax[i].spines["top"].set_visible(False) + ax[i].spines["right"].set_visible(False) + ax[i].spines["left"].set_visible(False) + + ax[i].yaxis.set_major_locator(MaxNLocator(nbins=3)) + + if i < len(tags) - 1: + ax[i].set_xticks([]) + ax[i].spines["bottom"].set_visible(False) + ax[i].xaxis.set_ticks_position('none') + else: + ax[i].set_xlabel('Time', fontsize=14) + ax[i].xaxis.set_major_locator(MaxNLocator(nbins=5)) + ax[i].set_xticklabels(timestamp_label.strftime('%Y-%m-%d %H:%M:%S'), rotation=0, ha='right') + + return filtered_by_data, tags