Names of Modules? #103
A MWE of something close to what I want is:

```python
from collections import defaultdict

import networkx as nx
import torch
import torchvision
from torchview import draw_graph

def model_layers(model):
    """ Extract named "leaf" layers from a module """
    stack = [('', '', model)]
    while stack:
        prefix, basename, item = stack.pop()
        name = '.'.join([p for p in [prefix, basename] if p])
        if isinstance(item, torch.nn.modules.conv._ConvNd):
            yield name, item
        elif isinstance(item, torch.nn.modules.batchnorm._BatchNorm):
            yield name, item
        elif hasattr(item, 'reset_parameters'):
            yield name, item
        child_prefix = name
        for child_basename, child_item in list(item.named_children())[::-1]:
            stack.append((child_prefix, child_basename, child_item))
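# For reference (illustrative, abbreviated): for resnet18 this generator
# yields dotted names paired with the leaf modules, e.g.
#   ('conv1', Conv2d(...)), ('bn1', BatchNorm2d(...)), ..., ('fc', Linear(...))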
# Create example network
net = torchvision.models.resnet18()
model_graph = draw_graph(net, input_size=(2, 3, 224, 224), device='meta')
# Remember the dotted layer name associated with each torch.Module
# instance. Usually a module will just have one name associated to an
# instance, but it could have more than one.
named_layers = list(model_layers(net))
id_to_names = defaultdict(list)
for name, layer in named_layers:
    layer_id = id(layer)
    id_to_names[layer_id].append(name)
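# Note: if one module instance is assigned to several attributes it gets
# several dotted names. Hypothetical illustration (not part of this MWE;
# assumes a torch version with the remove_duplicate flag on named_modules):
#   shared = torch.nn.Conv2d(3, 3, 1)
#   m = torch.nn.Sequential(); m.a = shared; m.b = shared
#   [n for n, c in m.named_modules(remove_duplicate=False) if c is shared]
#   # -> ['a', 'b']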

def make_label(n, data):
    """ Create a nice printable label """
    n_id = id(n)
    n_id_str = str(n_id)
    parts = []
    if 'layer_name' in data:
        parts.append(data['layer_name'] + ':')
    parts.append(n.name)
    if n_id_str in model_graph.id_dict:
        idx = model_graph.id_dict[n_id_str]
        parts.append(f':{idx}')
    if n_id in id_to_names:
        # id_to_names holds lists of names, so join before concatenating
        parts.append(' ' + ','.join(id_to_names[n_id]))
    label = ''.join(parts)
    return label
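# e.g. the ModuleNode for the first convolution ends up labeled 'conv1:Conv2d:1'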
# Build a networkx version of the torchview model graph
graph = nx.DiGraph()
for node in model_graph.node_set:
    graph.add_node(node)
for u, v in model_graph.edge_list:
    u_id = id(u)
    v_id = id(v)
    graph.add_edge(u_id, v_id)
    graph.nodes[u_id]['compute_node'] = u
    graph.nodes[v_id]['compute_node'] = v
# Enrich each node with more info
for n_id, data in graph.nodes(data=True):
    if 'compute_node' in data:
        n = data['compute_node']
        if hasattr(n, 'compute_unit_id'):
            if n.compute_unit_id in id_to_names:
                layer_names = id_to_names[n.compute_unit_id]
                if len(layer_names) == 1:
                    data['layer_name'] = layer_names[0]
                else:
                    # keep the full list when one instance has several names
                    data['layer_names'] = layer_names
        data['label'] = make_label(n, data)
nx.write_network_text(graph, vertical_chains=1)
# model_graph.visual_graph.view()
```

Produces:
You can see here that I've been able to associate many of the nodes with their original layer names. However, my solution just assumes each instance is only used once. I think a correct solution would attempt to know which module attribute was the caller, which gets tricky if you assign multiple instances of a module to different variables. To further process this into what I'm actually interested in, I do something like this:

```python
# Now that we have a graph where a subset of nodes correspond to known
# layers, we can postprocess it to only show effective connections between
# the layers.
# Determine which nodes have associated layer names
remove_ids = []
keep_ids = []
for n_id, data in graph.nodes(data=True):
    if 'layer_name' in data:
        keep_ids.append(n_id)
    else:
        remove_ids.append(n_id)
import ubelt as ub
topo_order = ub.OrderedSet(nx.topological_sort(graph))
keep_topo_order = (topo_order & keep_ids)
# Find the nearest ancestor that we want to view and collapse the nodes we
# don't care about into it. Do a final relabeling to keep the original node
# ids where possible.
collapseables = defaultdict(list)
for n in remove_ids:
    valid_prev_nodes = keep_topo_order & set(nx.ancestors(graph, n))
    if valid_prev_nodes:
        p = valid_prev_nodes[-1]
        collapseables[p].append(n)
from networkx.algorithms.connectivity.edge_augmentation import collapse
grouped_nodes = []
for p, vs in collapseables.items():
    grouped_nodes.append([p, *vs])
g2 = collapse(graph, grouped_nodes)
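# (Aside, a minimal illustration of `collapse`: each group becomes a single
# integer-labeled node and the old -> new mapping is stored on the result,
# e.g. collapse(nx.path_graph(4), [[1, 2]]).graph['mapping'] gives roughly
# {1: 0, 2: 0, 0: 1, 3: 2}; the exact integer labels are an implementation
# detail.)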
relabel = {n: n for n in g2.nodes}
new_to_olds = ub.udict(g2.graph['mapping']).invert(unique_vals=0)
for new, olds in new_to_olds.items():
    if len(olds) == 1:
        old = ub.peek(olds)
        relabel[new] = old
    else:
        keep_olds = keep_topo_order & olds
        old = ub.peek(keep_olds)
        relabel[new] = old
g3 = nx.relabel_nodes(g2, relabel)
def transfer_data(g_dst, g_src):
    for n in set(g_dst.nodes) & set(g_src.nodes):
        g_dst.nodes[n].update(g_src.nodes[n])
# Show the collapsed graph
transfer_data(g3, graph)
nx.write_network_text(g3, vertical_chains=1)
# Further reduce the graph to remove skip connection information
g4 = nx.transitive_reduction(g3)
transfer_data(g4, graph)
nx.write_network_text(g4, vertical_chains=1)
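# (The same idea applied to the original, uncollapsed graph: the transitive
# reduction of the transitive closure is the minimal graph with the same
# reachability between nodes.)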
g2 = nx.transitive_closure(graph)
g2 = nx.transitive_reduction(g2)
transfer_data(g2, graph)
```

This shows the graph where the intermediate functional nodes have been collapsed into one of their parent layers:

```
╟── auxiliary-tensor
╎ ╽
╎ conv1:Conv2d:1
╎ ╽
╎ bn1:BatchNorm2d:2
╎ ├─╼ layer1.0.conv1:Conv2d:5
╎ │ ╽
╎ │ layer1.0.bn1:BatchNorm2d:6
╎ │ ╽
╎ │ layer1.0.conv2:Conv2d:8
╎ │ ╽
╎ │ layer1.0.bn2:BatchNorm2d:9 ╾ bn1:BatchNorm2d:2
╎ │ ├─╼ layer1.1.conv1:Conv2d:12
╎ │ │ ╽
╎ │ │ layer1.1.bn1:BatchNorm2d:13
╎ │ │ ╽
╎ │ │ layer1.1.conv2:Conv2d:15
╎ │ │ ╽
╎ │ │ layer1.1.bn2:BatchNorm2d:16 ╾ layer1.0.bn2:BatchNorm2d:9
╎ │ │ ├─╼ layer2.0.bn2:BatchNorm2d:23 ╾ layer2.0.conv2:Conv2d:22
╎ │ │ │ ├─╼ layer2.1.conv1:Conv2d:27
╎ │ │ │ │ ╽
╎ │ │ │ │ layer2.1.bn1:BatchNorm2d:28
╎ │ │ │ │ ╽
╎ │ │ │ │ layer2.1.conv2:Conv2d:30
╎ │ │ │ │ ╽
╎ │ │ │ │ layer2.1.bn2:BatchNorm2d:31 ╾ layer2.0.bn2:BatchNorm2d:23
╎ │ │ │ │ ├─╼ layer3.0.bn2:BatchNorm2d:38 ╾ layer3.0.conv2:Conv2d:37
╎ │ │ │ │ │ ├─╼ layer3.1.conv1:Conv2d:42
╎ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ layer3.1.bn1:BatchNorm2d:43
╎ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ layer3.1.conv2:Conv2d:45
╎ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ layer3.1.bn2:BatchNorm2d:46 ╾ layer3.0.bn2:BatchNorm2d:38
╎ │ │ │ │ │ │ ├─╼ layer4.0.bn2:BatchNorm2d:53 ╾ layer4.0.conv2:Conv2d:52
╎ │ │ │ │ │ │ │ ├─╼ layer4.1.conv1:Conv2d:57
╎ │ │ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ │ │ layer4.1.bn1:BatchNorm2d:58
╎ │ │ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ │ │ layer4.1.conv2:Conv2d:60
╎ │ │ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ │ │ layer4.1.bn2:BatchNorm2d:61 ╾ layer4.0.bn2:BatchNorm2d:53
╎ │ │ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ │ │ fc:Linear:66
╎ │ │ │ │ │ │ │ └─╼ ...
╎ │ │ │ │ │ │ └─╼ layer4.0.conv1:Conv2d:49
╎ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ layer4.0.bn1:BatchNorm2d:50
╎ │ │ │ │ │ │ ╽
╎ │ │ │ │ │ │ layer4.0.conv2:Conv2d:52
╎ │ │ │ │ │ │ └─╼ ...
╎ │ │ │ │ │ └─╼ ...
╎ │ │ │ │ └─╼ layer3.0.conv1:Conv2d:34
╎ │ │ │ │ ╽
╎ │ │ │ │ layer3.0.bn1:BatchNorm2d:35
╎ │ │ │ │ ╽
╎ │ │ │ │ layer3.0.conv2:Conv2d:37
╎ │ │ │ │ └─╼ ...
╎ │ │ │ └─╼ ...
╎ │ │ └─╼ layer2.0.conv1:Conv2d:19
╎ │ │ ╽
╎ │ │ layer2.0.bn1:BatchNorm2d:20
╎ │ │ ╽
╎ │ │ layer2.0.conv2:Conv2d:22
╎ │ │ └─╼ ...
╎ │ └─╼ ...
╎ └─╼ ...
╙── 139679377001936
```

And finally, what I ultimately want to see is the transitive reduction of this graph.
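As a minimal sketch of what that transitive reduction does (toy graph, not the real model graph): it removes edges that are implied by longer paths, which is exactly what strips away the skip connections.

```python
import networkx as nx

# A residual-style block: x -> conv -> bn -> out, plus a skip edge x -> out
g = nx.DiGraph([('x', 'conv'), ('conv', 'bn'), ('bn', 'out'), ('x', 'out')])
print(sorted(nx.transitive_reduction(g).edges))
# [('bn', 'out'), ('conv', 'bn'), ('x', 'conv')]  -- the skip edge is gone
```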
Is your feature request related to a problem? Please describe.
I'm looking for a way to visualize the information flow of a network in terms of nested module names. I.e., suppose we extract torch modules of specific types, as with the `model_layers` generator in the MWE above:
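A call along those lines, reusing the `model_layers` definition from above (a reconstruction; the original snippet is assumed, not quoted):

```python
named_layers = list(model_layers(torchvision.models.resnet18()))
```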
which for torchvision.resnet18 looks like:
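Abbreviated (the full listing is long; layer names are from torchvision's resnet18):

```
conv1           Conv2d
bn1             BatchNorm2d
layer1.0.conv1  Conv2d
layer1.0.bn1    BatchNorm2d
...
layer4.1.bn2    BatchNorm2d
fc              Linear
```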
I would like to be able to determine which layers are conceptually connected (i.e., whether the outputs of one layer eventually make it to the inputs of another layer).
Describe the solution you'd like
Currently when I do something like:
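Presumably something along the lines of the `draw_graph` call from the MWE above (a reconstruction, not the original snippet):

```python
model_graph = draw_graph(net, input_size=(2, 3, 224, 224), device='meta')
model_graph.visual_graph.view()
```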
I don't see any recorded layer names. I'm wondering if it's possible in `collect_graph` (https://github.com/mert-kurttutan/torchview/blob/main/torchview/computation_graph.py#L188) to also associate layer names with nodes as they are extracted.
Describe alternatives you've considered
I tried doing this myself, but there are a lot of corner cases. This library seems to have the best approach to torch graph extraction that I've seen so far.
Additional context
I raised a similar issue on the torch discussion page: https://discuss.pytorch.org/t/tracing-a-graph-of-torch-layers/187615