NOTE(review): this patch was re-flowed from a whitespace-collapsed copy. Indentation and blank
context lines are reconstructed from the hunk counts, and the blob hashes on the "index" lines
predate the edits to the added code - regenerate with `git diff` before applying.

diff --git a/devtools/conda-envs/full_env.yaml b/devtools/conda-envs/full_env.yaml
index 5ad5c940..bbe807d6 100644
--- a/devtools/conda-envs/full_env.yaml
+++ b/devtools/conda-envs/full_env.yaml
@@ -19,6 +19,7 @@ dependencies:
   - pandas
   - pytables
   - matplotlib
+  - plotly
   - mpmath
 
   #Docs
diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml
index 92c7e794..88a08745 100644
--- a/devtools/conda-envs/test_env.yaml
+++ b/devtools/conda-envs/test_env.yaml
@@ -19,6 +19,7 @@ dependencies:
   - pandas
   - pytables
   - matplotlib
+  - plotly
 
 
   # Pip-only installs
diff --git a/reeds/function_libs/analysis/sampling.py b/reeds/function_libs/analysis/sampling.py
index 707e0c80..c095a197 100644
--- a/reeds/function_libs/analysis/sampling.py
+++ b/reeds/function_libs/analysis/sampling.py
@@ -6,6 +6,7 @@ import pandas as pd
 
 
 import reeds.function_libs.visualization.sampling_plots
+from pygromos.files.repdat import ExpandedRepdat
 
 
 def undersampling_occurence_potential_threshold_densityClustering(ene_trajs: List[pd.DataFrame],
@@ -349,7 +350,75 @@ def sampling_analysis(ene_trajs: List[pd.DataFrame],
     return final_results, out_path
 
 
+def analyse_state_transitions(repdat: ExpandedRepdat, min_s: int = None, normalize: bool = False, bidirectional: bool = False):
+    """
+    Count the number of times a transition occurs between pairs of states, based on the repdat info.
+
+    For every coordinate trajectory (``coord_ID``) the state label in the ``Vmin`` column is compared
+    between consecutive exchange trials (``run``); each change of label counts as one transition from
+    the old to the new state.
+
+    Parameters
+    ----------
+    repdat: ExpandedRepdat
+        ExpandedRepdat object (created from a Repdat) which contains all the exchange information of a
+        RE-EDS simulation plus the potential energies of the end-states
+    min_s: int, optional
+        Index of the lowest s_value to consider for the transitions (only rows with ID <= min_s are
+        used). If None, consider all s values.
+    normalize: bool, optional
+        Normalize the transitions by the total number of outgoing transitions per state
+    bidirectional: bool, optional
+        Count the transitions symmetrically (state A to B together with state B to A)
+
+    Returns
+    -------
+    np.ndarray
+        num_states * num_states array; element [i, j] is the number (or, if normalized, the fraction)
+        of transitions from state i+1 to state j+1
+
+    Raises
+    ------
+    ValueError
+        if both normalize and bidirectional are requested
+    """
+    if normalize and bidirectional:
+        raise ValueError("Transitions cannot be normalized w.r.t leaving state and bidirectional")
+
+    num_replicas = len(repdat.system.s)
+    num_states = len(repdat.system.state_eir)
+    # Initialize transition counts to zero for all pairs of states
+    transition_counts = np.zeros((num_states, num_states))
+
+    for replica in range(1, num_replicas + 1):
+        # Get exchange data per replica; compare against None so that min_s = 0 is not ignored
+        selection = f"coord_ID == {replica}"
+        if min_s is not None:
+            selection += f" & ID <= {min_s}"
+        replica_repdat = repdat.DATA.query(selection)
+
+        # Vmin labels look like "Vr<i>": keep the digits and shift to 0-based state indices
+        labels = replica_repdat["Vmin"]
+        states = np.array([int("".join(filter(str.isdigit, label))) for label in labels], dtype=int) - 1
+        runs = replica_repdat["run"].to_numpy()
+
+        # Neighbouring rows are not necessarily consecutive exchanges, so check the run numbers too
+        is_transition = (runs[1:] == runs[:-1] + 1) & (states[1:] != states[:-1])
+        # np.add.at accumulates repeated index pairs (plain fancy-index "+=" would count them once)
+        np.add.at(transition_counts, (states[:-1][is_transition], states[1:][is_transition]), 1)
+
+    if normalize:
+        # Normalize by total number of outgoing transitions per state; a state that is never left
+        # keeps a row of zeros instead of NaN from a division by zero
+        tot_trans = transition_counts.sum(axis=1, keepdims=True)
+        transition_counts = np.divide(transition_counts, tot_trans,
+                                      out=np.zeros_like(transition_counts), where=tot_trans > 0)
+    elif bidirectional:
+        # Consider exchanges in both directions together
+        transition_counts = transition_counts + transition_counts.T
+
+    return transition_counts
 
 
 def detect_undersampling(ene_trajs: List[pd.DataFrame],
                          state_potential_treshold: List[float],
diff --git a/reeds/function_libs/visualization/sampling_plots.py b/reeds/function_libs/visualization/sampling_plots.py
index 64a5a054..b99bb018 100644
--- a/reeds/function_libs/visualization/sampling_plots.py
+++ b/reeds/function_libs/visualization/sampling_plots.py
@@ -1,12 +1,16 @@
-from typing import List
+from typing import Union, List
 
 import numpy as np
 from matplotlib import pyplot as plt
+from matplotlib.colors import Colormap, to_rgba
+
+import plotly.graph_objects as go
+from plotly.colors import convert_to_RGB_255
 
 from reeds.function_libs.visualization import plots_style as ps
 from reeds.function_libs.visualization.utils import nice_s_vals
 
-import reeds.function_libs.visualization.plots_style as ps
+
 
 def plot_sampling_convergence(ene_trajs, opt_trajs, outfile, title = None, trim_beg = 0.1):
     """
@@ -377,3 +381,89 @@ def plot_stateOccurence_matrix(data: dict,
     if (not out_dir is None):
         fig.savefig(out_dir + '/sampling_maxContrib_matrix.png', bbox_inches='tight')
         plt.close()
+
+
+def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: Union[List[str], Colormap] = ps.qualitative_tab_map, out_path: str = None):
+    """
+    Make a Sankey plot showing the flows between states.
+
+    Parameters
+    ----------
+    state_transitions : np.ndarray
+        num_states * num_states 2D array containing the number of transitions between states
+        (rows are the source states, columns are the target states)
+    title: str, optional
+        printed title of the plot
+    colors: Union[List[str], Colormap], optional
+        one color per state, or a matplotlib Colormap from which num_states evenly spaced colors
+        are drawn
+    out_path: str, optional
+        path to save the image to. If None, the image is returned as a plotly figure
+
+    Returns
+    -------
+    None or fig
+        plotly figure if it was not saved
+
+    Raises
+    ------
+    ValueError
+        if state_transitions is not a square 2D array or fewer colors than states are given
+    """
+    state_transitions = np.asarray(state_transitions)
+    if state_transitions.ndim != 2 or state_transitions.shape[0] != state_transitions.shape[1]:
+        raise ValueError("state_transitions must be a square 2D array")
+    num_states = len(state_transitions)
+
+    if isinstance(colors, Colormap):
+        colors = [colors(i) for i in np.linspace(0, 1, num_states)]
+    elif len(colors) < num_states:
+        raise ValueError("Insufficient colors to plot all states")
+
+    def v_distribute(total_transitions):
+        # Vertically distribute nodes in plot based on total number of transitions per state
+        grand_total = total_transitions.sum()
+        if grand_total > 0:
+            box_sizes = total_transitions / grand_total
+        else:
+            # No transitions at all: spread the nodes evenly instead of dividing by zero
+            box_sizes = np.full(len(total_transitions), 1 / max(len(total_transitions), 1))
+        # Center of each box = sizes of all boxes before it + half of its own size
+        return (np.cumsum(box_sizes) - box_sizes / 2).tolist()
+
+    # Source nodes are placed by outgoing (row) totals, target nodes by incoming (column) totals
+    y_placements = v_distribute(state_transitions.sum(axis=1)) + v_distribute(state_transitions.sum(axis=0))
+
+    # Convert colors to plotly "rgba(r, g, b, a)" strings and make them slightly transparent
+    rgba_colors = []
+    for color in colors[:num_states]:
+        red, green, blue = convert_to_RGB_255(to_rgba(color)[:3])
+        rgba_colors.append(f"rgba({red}, {green}, {blue}, 0.8)")
+
+    # Indices 0..n-1 are the source and n..2n-1 are the target.
+    node_ids = np.arange(num_states)
+    fig = go.Figure(data=[go.Sankey(
+        node=dict(
+            pad=5,
+            thickness=20,
+            line=dict(color="black", width=2),
+            label=[f"state {i + 1}" for i in range(num_states)] * 2,
+            color=rgba_colors * 2,
+            x=[0.1] * num_states + [1] * num_states,
+            y=y_placements,
+        ),
+        link=dict(
+            arrowlen=30,
+            source=np.repeat(node_ids, num_states),
+            target=np.tile(node_ids + num_states, num_states),
+            value=state_transitions.flatten(),
+            color=np.repeat(rgba_colors, num_states),
+        ),
+        arrangement="fixed",
+    )])
+    fig.update_layout(title_text=title, font_size=20, title_x=0.5, height=max(600, num_states * 100))
+
+    if out_path:
+        fig.write_image(out_path)
+        return None
+    return fig
diff --git a/reeds/submodules/pygromos b/reeds/submodules/pygromos
index 47837858..37b32f2b 160000
--- a/reeds/submodules/pygromos
+++ b/reeds/submodules/pygromos
@@ -1 +1 @@
-Subproject commit 4783785811265b169f26d16e881e662a9d58316d
+Subproject commit 37b32f2bb897cb1b3b0a0224b010fbf03da7c7b6