Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add plots to show the change in maxContrib state #96

Merged
merged 8 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions devtools/conda-envs/full_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- pandas
- pytables
- matplotlib
- plotly
- mpmath

#Docs
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- pandas
- pytables
- matplotlib
- plotly


# Pip-only installs
Expand Down
60 changes: 60 additions & 0 deletions reeds/function_libs/analysis/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd

import reeds.function_libs.visualization.sampling_plots
from pygromos.files.repdat import ExpandedRepdat


def undersampling_occurence_potential_threshold_densityClustering(ene_trajs: List[pd.DataFrame],
Expand Down Expand Up @@ -349,7 +350,66 @@ def sampling_analysis(ene_trajs: List[pd.DataFrame],

return final_results, out_path

def analyse_state_transitions(repdat: ExpandedRepdat, min_s: int = None, normalize: bool = False, bidirectional: bool = False):
    """
    Count how often each replica switches from one state to another.

    For every replica, the rows of ``repdat.DATA`` belonging to that replica
    (``coord_ID``) are read in their stored order. Whenever two successive rows
    belong to consecutive exchange trials (``run`` differs by exactly one) and
    their ``Vmin`` labels name different states, one transition
    ``previous state -> next state`` is counted.

    Parameters
    ----------
    repdat: ExpandedRepdat
        ExpandedRepdat object (created from a Repdat) which contains all the exchange information of a
        RE-EDS simulation plus the potential energies of the end-states. This function reads
        ``repdat.system.s``, ``repdat.system.state_eir`` and the ``coord_ID``, ``ID``, ``run`` and
        ``Vmin`` columns of ``repdat.DATA``.
    min_s: int, optional
        Index of the lowest s_value to consider for the transitions (rows with ``ID <= min_s`` are kept).
        If None, consider all s values.
    normalize: bool, optional
        Normalize the transitions by the total number of outgoing transitions per state.
        States without any outgoing transition keep a row of zeros.
    bidirectional: bool, optional
        Count the transitions symmetrically (state A to B together with state B to A)

    Returns
    -------
    np.ndarray
        num_states * num_states array; entry [a, b] holds the (possibly normalized or
        symmetrized) number of transitions from state a+1 to state b+1

    Raises
    ------
    ValueError
        if both ``normalize`` and ``bidirectional`` are requested
    """
    if normalize and bidirectional:
        raise ValueError("Transitions cannot be normalized w.r.t leaving state and bidirectional")

    num_replicas = len(repdat.system.s)
    num_states = len(repdat.system.state_eir)

    # Initialize transition counts to zero for all pairs of states
    transition_counts = np.zeros((num_states, num_states))

    for replica in range(1, num_replicas + 1):
        # Select the exchange records of this replica; compare against None so
        # that an explicit min_s of 0 is not silently treated as "no filter".
        selection = f"coord_ID == {replica}"
        if min_s is not None:
            selection = f"coord_ID == {replica} & ID <= {min_s}"
        exchanges = repdat.DATA.query(selection)

        # At least two records are needed to observe a transition
        if len(exchanges) < 2:
            continue

        # Keep only the digits of each label (the i in "Vri") as the 1-based state number
        states = exchanges["Vmin"].map(
            lambda label: int("".join(filter(str.isdigit, label)))
        ).to_numpy()
        runs = exchanges["run"].to_numpy()

        # A transition needs two directly consecutive exchange trials with different states
        is_transition = (runs[1:] == runs[:-1] + 1) & (states[1:] != states[:-1])
        leaving = states[:-1][is_transition] - 1
        entering = states[1:][is_transition] - 1

        # np.add.at accumulates correctly when the same (leaving, entering) pair repeats
        np.add.at(transition_counts, (leaving, entering), 1)

    if normalize:
        # Normalize by total number of outgoing transitions per state; rows that
        # sum to zero are left at zero instead of producing NaN (0/0).
        tot_trans = transition_counts.sum(axis=1, keepdims=True)
        transition_counts = np.divide(
            transition_counts,
            tot_trans,
            out=np.zeros_like(transition_counts),
            where=tot_trans != 0,
        )

    elif bidirectional:
        # Consider exchanges in both directions together
        transition_counts = transition_counts + transition_counts.T

    return transition_counts

def detect_undersampling(ene_trajs: List[pd.DataFrame],
state_potential_treshold: List[float],
Expand Down
80 changes: 78 additions & 2 deletions reeds/function_libs/visualization/sampling_plots.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import List
from typing import Union, List

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import Colormap, to_rgba

import plotly.graph_objects as go
from plotly.colors import convert_to_RGB_255

from reeds.function_libs.visualization import plots_style as ps
from reeds.function_libs.visualization.utils import nice_s_vals

import reeds.function_libs.visualization.plots_style as ps


def plot_sampling_convergence(ene_trajs, opt_trajs, outfile, title = None, trim_beg = 0.1):
"""
Expand Down Expand Up @@ -377,3 +381,75 @@ def plot_stateOccurence_matrix(data: dict,
if (not out_dir is None):
fig.savefig(out_dir + '/sampling_maxContrib_matrix.png', bbox_inches='tight')
plt.close()

def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: Union[List[str], Colormap] = ps.qualitative_tab_map, out_path: str = None):
    """
    Draw the transitions between states as a Sankey diagram.

    Every state appears twice: once as a source node on the left and once as a
    target node on the right. The link from source a to target b carries
    ``state_transitions[a][b]`` and is drawn in the color of the source state.

    Parameters
    ----------
    state_transitions : np.ndarray
        num_states * num_states 2D array containing the number of transitions between states
    title: str, optional
        title printed above the plot
    colors: Union[List[str], Colormap], optional
        one color per state; a Colormap is sampled at num_states evenly spaced
        points between 0 and 1
    out_path: str, optional
        path the image is written to. If not given, the plotly figure is returned instead

    Returns
    -------
    None or fig
        the plotly figure when it was not saved, otherwise None

    Raises
    ------
    Exception
        if a list of colors holds fewer entries than there are states
    """
    num_states = len(state_transitions)

    if isinstance(colors, Colormap):
        colors = [colors(fraction) for fraction in np.linspace(0, 1, num_states)]
    elif len(colors) < num_states:
        raise Exception("Insufficient colors to plot all states")

    def node_centers(totals):
        # Stack the nodes vertically, each taking a share proportional to its
        # total, and return the vertical centre of every node as a list.
        shares = totals / totals.sum()
        return [np.sum(shares[:idx]) + share / 2 for idx, share in enumerate(shares)]

    # Source nodes are sized by outgoing totals (rows), target nodes by incoming totals (columns)
    outgoing_y = node_centers(np.sum(state_transitions, axis=1))
    incoming_y = node_centers(np.sum(state_transitions, axis=0))
    node_y = outgoing_y + incoming_y

    # Translate every color into a plotly "rgba(r, g, b, a)" string with opacity 0.8
    plotly_colors = [
        "rgba" + str(convert_to_RGB_255(to_rgba(color)[:-1]) + (0.8,))
        for color in colors
    ]
    state_colors = plotly_colors[:num_states]

    # Node indices 0..n-1 are the sources and n..2n-1 are the targets.
    nodes = dict(
        pad=5,
        thickness=20,
        line=dict(color="black", width=2),
        label=[f"state {i+1}" for i in range(num_states)] * 2,
        color=state_colors * 2,
        x=[0.1] * num_states + [1] * num_states,
        y=node_y,
    )
    # One link per (source, target) pair, in the row-major order of state_transitions
    links = dict(
        arrowlen=30,
        source=np.repeat(np.arange(num_states), num_states),
        target=np.tile(np.arange(num_states, 2 * num_states), num_states),
        value=state_transitions.flatten(),
        color=np.repeat(np.array(state_colors), num_states),
    )

    fig = go.Figure(data=[go.Sankey(node=nodes, link=links, arrangement="fixed")])
    fig.update_layout(title_text=title, font_size=20, title_x=0.5, height=max(600, num_states*100))

    if not out_path:
        return fig

    fig.write_image(out_path)
    return None
2 changes: 1 addition & 1 deletion reeds/submodules/pygromos