Skip to content

Commit

Permalink
Implement a draw method in scenario config for debug purpose. (#2289)
Browse files Browse the repository at this point in the history
* scenario_config.draw

* minor changes

* wrong import

* fix f-string

* minor formatting

* Update config.pyi

* Apply suggestions from code review

Co-authored-by: Đỗ Trường Giang <[email protected]>

---------

Co-authored-by: jrobinAV <[email protected]>
Co-authored-by: Đỗ Trường Giang <[email protected]>
  • Loading branch information
3 people authored Nov 29, 2024
1 parent 45cb211 commit fcd6497
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 5 deletions.
4 changes: 2 additions & 2 deletions taipy/common/config/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ class Config:
corresponds to the data node configuration id. During the scenarios'
comparison, each comparator is applied to all the data nodes instantiated from
the data node configuration attached to the comparator. See
`(taipy.)compare_scenarios()^` more more details.
`(taipy.)compare_scenarios()^` more details.
sequences (Optional[Dict[str, List[TaskConfig]]]): Dictionary of sequence descriptions.
The default value is None.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Expand Down Expand Up @@ -321,7 +321,7 @@ class Config:
corresponds to the data node configuration id. During the scenarios'
comparison, each comparator is applied to all the data nodes instantiated from
the data node configuration attached to the comparator. See
`taipy.compare_scenarios()^` more more details.
`taipy.compare_scenarios()^` more details.
sequences (Optional[Dict[str, List[TaskConfig]]]): Dictionary of sequences. The default value is None.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Expand Down
66 changes: 63 additions & 3 deletions taipy/core/config/scenario_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class ScenarioConfig(Section):
_TASKS_KEY = "tasks"
_ADDITIONAL_DATA_NODES_KEY = "additional_data_nodes"
_FREQUENCY_KEY = "frequency"
_SEQUENCES_KEY = "sequences"
_COMPARATOR_KEY = "comparators"

frequency: Optional[Frequency]
Expand Down Expand Up @@ -305,7 +304,7 @@ def _configure(
corresponds to the data node configuration id. During the scenarios'
comparison, each comparator is applied to all the data nodes instantiated from
the data node configuration attached to the comparator. See
`(taipy.)compare_scenarios()^` more more details.
`(taipy.)compare_scenarios()^` more details.
sequences (Optional[Dict[str, List[TaskConfig]]]): Dictionary of sequence descriptions.
The default value is None.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Expand Down Expand Up @@ -355,7 +354,7 @@ def _set_default_configuration(
corresponds to the data node configuration id. During the scenarios'
comparison, each comparator is applied to all the data nodes instantiated from
the data node configuration attached to the comparator. See
`taipy.compare_scenarios()^` more more details.
`taipy.compare_scenarios()^` more details.
sequences (Optional[Dict[str, List[TaskConfig]]]): Dictionary of sequences. The default value is None.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Expand All @@ -373,3 +372,64 @@ def _set_default_configuration(
)
Config._register(section)
return Config.sections[ScenarioConfig.name][_Config.DEFAULT_KEY]

def draw(self, file_path: Optional[str]=None) -> None:
"""
Export the scenario configuration graph as a PNG file.
This function uses the `matplotlib` library to draw the scenario configuration graph.
`matplotlib` must be installed independently of `taipy` as it is not a dependency.
If `matplotlib` is not installed, the function will log an error message, and do nothing.
Arguments:
file_path (Optional[str]): The path to save the PNG file.
If not provided, the file will be saved with the scenario configuration id.
"""
from importlib import util

from taipy.common.logger._taipy_logger import _TaipyLogger
logger = _TaipyLogger._get_logger()

if not util.find_spec("matplotlib"):
logger.error("Cannot draw the scenario configuration as `matplotlib` is not installed.")
return
import matplotlib.pyplot as plt
import networkx as nx

from taipy.core._entity._dag import _DAG

def build_dag() -> nx.DiGraph:
g = nx.DiGraph()
for task in set(self.tasks):
if has_input := task.inputs:
for predecessor in task.inputs:
g.add_edges_from([(predecessor, task)])
if has_output := task.outputs:
for successor in task.outputs:
g.add_edges_from([(task, successor)])
if not has_input and not has_output:
g.add_node(task)
return g
graph = build_dag()
dag = _DAG(graph)
pos = {node.entity: (node.x, node.y) for node in dag.nodes.values()}
labls = {node.entity: node.entity.id for node in dag.nodes.values()}

# Draw the graph
plt.figure(figsize=(10, 10))
nx.draw_networkx_nodes(graph, pos,
nodelist=[node for node in graph.nodes if isinstance(node, DataNodeConfig)],
node_color="skyblue",
node_shape="s",
node_size=2000)
nx.draw_networkx_nodes(graph, pos,
nodelist=[node for node in graph.nodes if isinstance(node, TaskConfig)],
node_color="orange",
node_shape="D",
node_size=2000)
nx.draw_networkx_labels(graph, pos, labels=labls)
nx.draw_networkx_edges(graph, pos, node_size=2000, edge_color="black", arrowstyle="->", arrowsize=25)
path = file_path or f"{self.id}.png"
plt.savefig(path)
plt.close() # Close the plot to avoid display
logger.info(f"The graph image of the scenario configuration `{self.id}` is exported: {path}")
79 changes: 79 additions & 0 deletions tests/core/config/test_scenario_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import os
from unittest import mock

import pytest

from taipy.common.config import Config
from taipy.common.config.common.frequency import Frequency
from tests.core.utils.named_temporary_file import NamedTemporaryFile
Expand Down Expand Up @@ -299,3 +301,80 @@ def test_add_sequence():
assert len(scenario_config.sequences) == 2
scenario_config.remove_sequences(["sequence2", "sequence3"])
assert len(scenario_config.sequences) == 0

@pytest.mark.skip(reason="Generates a png that must be visually verified.")
def test_draw_1():
dn_config_1 = Config.configure_data_node("dn1")
dn_config_2 = Config.configure_data_node("dn2")
dn_config_3 = Config.configure_data_node("dn3")
dn_config_4 = Config.configure_data_node("dn4")
dn_config_5 = Config.configure_data_node("dn5")
task_config_1 = Config.configure_task("task1", sum, input=[dn_config_1, dn_config_2], output=dn_config_3)
task_config_2 = Config.configure_task("task2", sum, input=[dn_config_1, dn_config_3], output=dn_config_4)
task_config_3 = Config.configure_task("task3", print, input=dn_config_4)
scenario_cfg = Config.configure_scenario(
"scenario1",
[task_config_1, task_config_2, task_config_3],
[dn_config_5],
)
scenario_cfg.draw()

@pytest.mark.skip(reason="Generates a png that must be visually verified.")
def test_draw_2():
data_node_1 = Config.configure_data_node("s1")
data_node_2 = Config.configure_data_node("s2")
data_node_4 = Config.configure_data_node("s4")
data_node_5 = Config.configure_data_node("s5")
data_node_6 = Config.configure_data_node("s6")
data_node_7 = Config.configure_data_node("s7")
task_1 = Config.configure_task("t1", print, [data_node_1, data_node_2], [data_node_4])
task_2 = Config.configure_task("t2", print, None, [data_node_5])
task_3 = Config.configure_task("t3", print, [data_node_5, data_node_4], [data_node_6])
task_4 = Config.configure_task("t4", print, [data_node_4], [data_node_7])
scenario_cfg = Config.configure_scenario("scenario1", [task_4, task_2, task_1, task_3])

# 6 | t2 _____
# 5 | \
# 4 | s5 _________________ t3 _______ s6
# 3 | s1 __ _ s4 _____/
# 2 | \ _ t1 ____/ \_ t4 _______ s7
# 1 | /
# 0 | s2 --
# |________________________________________________
# 0 1 2 3 4
scenario_cfg.draw("draw_2")

@pytest.mark.skip(reason="Generates a png that must be visually verified.")
def test_draw_3():
data_node_1 = Config.configure_data_node("s1")
data_node_2 = Config.configure_data_node("s2")
data_node_3 = Config.configure_data_node("s3")
data_node_4 = Config.configure_data_node("s4")
data_node_5 = Config.configure_data_node("s5")
data_node_6 = Config.configure_data_node("s6")
data_node_7 = Config.configure_data_node("s7")

task_1 = Config.configure_task("t1", print, [data_node_1, data_node_2, data_node_3], [data_node_4])
task_2 = Config.configure_task("t2", print, [data_node_4], None)
task_3 = Config.configure_task("t3", print, [data_node_4], [data_node_5])
task_4 = Config.configure_task("t4", print, None, output=[data_node_6])
task_5 = Config.configure_task("t5", print, [data_node_7], None)
scenario_cfg = Config.configure_scenario("scenario1", [task_5, task_3, task_4, task_2, task_1])


# 12 | s7 __
# 11 | \
# 10 | \
# 9 | t4 _ \_ t5
# 8 | \ ____ t3 ___
# 7 | \ / \
# 6 | s3 _ \__ s6 _ s4 _/ \___ s5
# 5 | \ / \
# 4 | \ / \____ t2
# 3 | s2 ___\__ t1 __/
# 2 | /
# 1 | /
# 0 | s1 _/
# |________________________________________________
# 0 1 2 3 4
scenario_cfg.draw("draw_3")

0 comments on commit fcd6497

Please sign in to comment.