Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made the colors of the graph parametrizable #115

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions tests/test_output/change_color_scheme.gv
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
strict digraph ColorSchemeChanged {
graph [ordering=in rankdir=TB size="12.0,12.0"]
node [align=left fontname="Linux libertine" fontsize=10 height=0.2 margin=0 ranksep=0.1 shape=plaintext style=filled]
edge [fontsize=10]
0 [label=<
<TABLE BORDER="0" CELLBORDER="1"
CELLSPACING="0" CELLPADDING="4">
<TR><TD>input-tensor<BR/>depth:0</TD><TD>(1, 32)</TD></TR>
</TABLE>> fillcolor=aquamarine1]
1 [label=<
<TABLE BORDER="0" CELLBORDER="1"
CELLSPACING="0" CELLPADDING="4">
<TR>
<TD ROWSPAN="2">Linear<BR/>depth:1</TD>
<TD COLSPAN="2">input:</TD>
<TD COLSPAN="2">(1, 32) </TD>
</TR>
<TR>
<TD COLSPAN="2">output: </TD>
<TD COLSPAN="2">(1, 2) </TD>
</TR>
</TABLE>> fillcolor=deepskyblue1]
2 [label=<
<TABLE BORDER="0" CELLBORDER="1"
CELLSPACING="0" CELLPADDING="4">
<TR>
<TD ROWSPAN="2">add<BR/>depth:1</TD>
<TD COLSPAN="2">input:</TD>
<TD COLSPAN="2">(1, 2) </TD>
</TR>
<TR>
<TD COLSPAN="2">output: </TD>
<TD COLSPAN="2">(1, 2) </TD>
</TR>
</TABLE>> fillcolor=turquoise1]
3 [label=<
<TABLE BORDER="0" CELLBORDER="1"
CELLSPACING="0" CELLPADDING="4">
<TR><TD>output-tensor<BR/>depth:0</TD><TD>(1, 2)</TD></TR>
</TABLE>> fillcolor=aquamarine1]
0 -> 1
1 -> 2
2 -> 3
}
Binary file added tests/test_output/change_color_scheme.gv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 20 additions & 0 deletions tests/test_torchview.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,23 @@ def test_isolated_tensor(verify_result: Callable[..., Any]) -> None:
)

verify_result([model_graph])


def test_change_color_scheme(verify_result: Callable[..., Any]) -> None:
input_data = [
torch.rand(1, 32),
]
model = IsolatedTensor()
model_graph = draw_graph(
model,
input_data=input_data,
expand_nested=True,
depth=3,
graph_name='ColorSchemeChanged',
colors={'TensorNode': "aquamarine1",
'ModuleNode': "deepskyblue1",
'FunctionNode': "turquoise1"},
device='cpu',
)

verify_result([model_graph])
25 changes: 13 additions & 12 deletions torchview/computation_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@

COMPUTATION_NODES = Union[TensorNode, ModuleNode, FunctionNode]

node2color = {
TensorNode: "lightyellow",
ModuleNode: "darkseagreen1",
FunctionNode: "aliceblue",
}

# TODO: Currently, we only use directed graphviz graph since DNN are
# graphs except for e.g. graph neural network (GNN). Experiment on GNN
# and see if undirected graphviz graph can be used to represent GNNs
Expand Down Expand Up @@ -62,6 +56,7 @@ def __init__(
visual_graph: Digraph,
root_container: NodeContainer[TensorNode],
show_shapes: bool = True,
colors: dict | None = None,
expand_nested: bool = False,
hide_inner_tensors: bool = True,
hide_module_functions: bool = True,
Expand Down Expand Up @@ -90,6 +85,15 @@ def __init__(
'col_span': 2,
'row_span': 2,
}

self.node2color = {
'TensorNode': 'lightyellow',
'ModuleNode': 'darkseagreen1',
'FunctionNode': 'aliceblue',
}
if colors:
self.node2color.update(colors)

self.reset_graph_history()

def reset_graph_history(self) -> None:
Expand Down Expand Up @@ -356,7 +360,7 @@ def add_node(
self.id_dict[node.node_id] = self.running_node_id
self.running_node_id += 1
label = self.get_node_label(node)
node_color = ComputationGraph.get_node_color(node)
node_color = self.get_node_color(node)

if subgraph is None:
subgraph = self.visual_graph
Expand Down Expand Up @@ -423,11 +427,8 @@ def resize_graph(
size_str = str(size) + "," + str(size)
self.visual_graph.graph_attr.update(size=size_str,)

@staticmethod
def get_node_color(
node: COMPUTATION_NODES
) -> str:
return node2color[type(node)]
def get_node_color(self, node: COMPUTATION_NODES) -> str:
return self.node2color[node.__class__.__name__]

def check_node(self, node: COMPUTATION_NODES) -> None:
assert node.node_id != 'null', f'wrong id {node} {type(node)}'
Expand Down
2 changes: 1 addition & 1 deletion torchview/computation_node/compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class TensorNode(Node):
'''Subclass of node specialzed for nodes that
'''Subclass of node specialized for nodes that
stores tensor (subclass of torch.Tensor called RecorderTensor)
'''
def __init__(
Expand Down
10 changes: 9 additions & 1 deletion torchview/torchview.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def draw_graph(
hide_inner_tensors: bool = True,
roll: bool = False,
show_shapes: bool = True,
colors: dict | None = None,
save_graph: bool = False,
filename: str | None = None,
directory: str = '.',
Expand Down Expand Up @@ -145,6 +146,13 @@ def draw_graph(
False => Dont show
Default: True

colors (dict):
Color scheme used for the plotting. The dictionary can
contain a graphiz color for "TensorNode", "ModuleNode"
and "FunctionNode". If not given, the default colors will
be used. Example: colors = {TensorNode: "lightyellow"}
Default: default color scheme

save_graph (bool):
True => Saves output file of graphviz graph
False => Does not save
Expand Down Expand Up @@ -213,7 +221,7 @@ def draw_graph(
)

model_graph = ComputationGraph(
visual_graph, input_nodes, show_shapes, expand_nested,
visual_graph, input_nodes, show_shapes, colors, expand_nested,
hide_inner_tensors, hide_module_functions, roll, depth
)

Expand Down