Skip to content

Commit

Permalink
cellnopt tutorial and data ploting
Browse files Browse the repository at this point in the history
  • Loading branch information
gabora committed Jun 14, 2024
1 parent 7dabb0f commit 36965f6
Show file tree
Hide file tree
Showing 2 changed files with 2,281 additions and 0 deletions.
124 changes: 124 additions & 0 deletions corneto/methods/signal/cellnopt_ilp.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,127 @@ def plot_fitness(G, exp_list, P, measured_only=False):
axs[iexp - 1, imarker].set_ylabel(f"Experiment {iexp}")

plt.show()

def collect_field_into_matrix(experiments,field_name = 'input'):
"""Collects the field_name values into matrix.
Collects the field_name values (input, inhibition etc) from a dictionary
of experiments and returns them as a numpy array.
PARAMETERS:
- experiments: dictionary of experiments containing input values
- field_name: name of the field to collect (default: 'input')
RETURNS:
- input_matrix: numpy array of input values
- input_vars: list of unique input variable names
"""
# Collect all unique input names
input_vars = set()
for exp in experiments.values():
# ensure field_name exists
if field_name not in exp:
raise ValueError('Field name (' + field_name + ') not found in experiment')
input_vars.update(exp[field_name].keys())

input_vars = sorted(input_vars) # Sorting to keep a consistent order
data = []

# Collect input values for each experiment
for exp in experiments.values():
row = [exp[field_name].get(var, 0) for var in input_vars]
data.append(row)

# Convert the data to a numpy array
input_matrix = np.array(data)

return input_matrix, input_vars

def plot_data(exp_list):
"""Plot the data.
PARAMETERS:
- exp_list: dictionary of experiments
"""
import matplotlib.pyplot as plt

N_exps = len(exp_list)

# Ensure that all experiments have the input and output variables
for exp in exp_list.values():
if 'input' not in exp:
raise ValueError('Input not found in experiment')
if 'output' not in exp:
raise ValueError('Output not found in experiment')

# Collect the input and output variables
input_matrix, input_vars = collect_field_into_matrix(exp_list, 'input')
output_matrix, output_vars = collect_field_into_matrix(exp_list, 'output')

# Check if inhibition is present in any of the experiments
inhibition_present = any('inhibition' in exp for exp in exp_list.values())
if inhibition_present:
inhibition_matrix, inhibition_vars = collect_field_into_matrix(exp_list, 'inhibition')
perturbation_matrix = np.hstack((input_matrix, inhibition_matrix))
perturbation_vars = input_vars + inhibition_vars
else:
perturbation_matrix = input_matrix
perturbation_vars = input_vars

# Create the figure
# Set colors: input colors are blue, inhibition colors are red
perturbation_colors = ['blue'] * len(input_vars)
if inhibition_present:
perturbation_colors += ['red'] * len(inhibition_vars)

fig, axs = plt.subplots(N_exps - 1, len(output_vars) + 1, squeeze=False)

fig.tight_layout(pad=0.0)
# Adjust the space between subplots
plt.subplots_adjust(wspace=0.1, hspace=0.1)

for exp, iexp in zip(exp_list, range(N_exps)):
if iexp == 0:
continue

for imarker in range(len(output_vars)):
# output_names[imarker] is the name of the output node, find the position in the graph
imarker_name = output_vars[imarker]

axs[iexp - 1, imarker].plot(
[0, 10],
[
exp_list["exp0"]["output"][imarker_name],
exp_list[exp]["output"][imarker_name],
],
"ro-",
)
axs[iexp - 1, imarker].set_ylim([-0.01, 1.1])

if iexp == 1:
axs[iexp - 1, imarker].set_title(imarker_name)
if iexp != N_exps - 1:
axs[iexp - 1, imarker].set_xticks([])
if imarker == 0:
axs[iexp - 1, imarker].set_ylabel(f"Exp. {iexp}")
else:
axs[iexp - 1, imarker].set_yticks([])

# Plot perturbation
axs[iexp - 1, len(output_vars)].bar(
range(len(perturbation_vars)), perturbation_matrix[iexp], color=perturbation_colors
)
if iexp == N_exps - 1:
axs[iexp - 1, len(output_vars)].set_xticks(range(len(perturbation_vars)))
axs[iexp - 1, len(output_vars)].set_xticklabels(perturbation_vars, rotation=45)
else:
# No xtick label
axs[iexp - 1, len(output_vars)].set_xticks([])

axs[iexp - 1, len(output_vars)].set_ylim([-0.01, 1.1])
axs[iexp - 1, len(output_vars)].set_yticks([])
if iexp == 1:
axs[iexp - 1, len(output_vars)].set_title("Pert.")
plt.show()
Loading

0 comments on commit 36965f6

Please sign in to comment.