Skip to content

Commit

Permalink
add tensor info(#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiayouran committed Oct 14, 2022
1 parent 8aab27d commit 1ae0563
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 22 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ export PYTHONPATH=$PYTHONPATH:${your-path}/VisuTVM
# visu relay ir(default: FuseOps)
python main.py -bp relay_ir/example_fo_bp.txt -ap relay_ir/example_fo_ap.txt -sn example

# visu relay ir with tensor info
python main.py -bp relay_ir/example_fo_bp.txt -ap relay_ir/example_fo_ap.txt -sn example -wi

# create relay ir txt file
python examples/example.py --passname FuseOps
```
Expand Down
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
help='relay ir after pass txt file')
parser.add_argument('--save_name', '-sn', type=str, default='example',
help='png save name')
parser.add_argument('--with_info', '-wi', action='store_true', default=False,
help='png save name')
args = parser.parse_args()


if __name__ == '__main__':
save_name = args.save_name
before_pass = args.before_pass
after_pass = args.after_pass
with_info = args.with_info

visu_relay_ir(before_pass, after_pass, save_name)
visu_relay_ir(before_pass, after_pass, save_name, with_info)
print('Finshed!')
16 changes: 8 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,34 @@ def relay_ir2txt(context, file_name='example', is_ap=False):
f.writelines(str(context))


def visu_relay_ir(bp_file, ap_file, save_name):
g = VisuGraph(txt_file=bp_file, save_name=save_name)
def visu_relay_ir(bp_file, ap_file, save_name, with_info=False):
g = VisuGraph(txt_file=bp_file, save_name=save_name, with_info=with_info)
g.codegen()

if '_fo_' in ap_file:
g = VisuGraphFuseOps(txt_file=ap_file, save_name=save_name)
g = VisuGraphFuseOps(txt_file=ap_file, save_name=save_name, with_info=with_info)
elif '_ruf_' in ap_file or '_fc_' in ap_file or '_ecs_' in ap_file or '_si_' in ap_file or '_fm_' in ap_file or \
'_se_' in ap_file or '_fac_' in ap_file or '_cc_' in ap_file or '_cl_' in ap_file or '_fsa_' in ap_file or \
'_cpc2d_' in ap_file or '_cpd_' in ap_file or '_cpbm_' in ap_file:
g = VisuGraphRUF(txt_file=ap_file, save_name=save_name)
g = VisuGraphRUF(txt_file=ap_file, save_name=save_name, with_info=with_info)
elif '_mc_' in ap_file:
g = VisuGraphMC(txt_file=ap_file, save_name=save_name)
g = VisuGraphMC(txt_file=ap_file, save_name=save_name, with_info=with_info)
else:
warnings.warn("not support the pass to visu now! ==> {}".format(ap_file))
# TODO 由于没有合适的case,部分Pass优化后的Relay IR可视化可能会失败
# 有些Pass在优化神经网络(目前只在resnet18上进行了测试)的时候可能不起作用,因此Pass优化前后的可视化结果是一样的
g = VisuGraphRUF(txt_file=ap_file, save_name=save_name)
g = VisuGraphRUF(txt_file=ap_file, save_name=save_name, with_info=with_info)
g.codegen()


def run_all_examples(scan_dir='relay_ir'):
def run_all_examples(scan_dir='relay_ir', with_info=False):
bp_list = glob.glob(os.path.join(scan_dir, '*_bp.txt'))
for bp_file in bp_list:
ap_file = bp_file.replace('_bp', '_ap')
save_name = bp_file.replace('.txt', '')

print("Parsing {} and {}".format(bp_file, ap_file))
visu_relay_ir(bp_file, ap_file, save_name)
visu_relay_ir(bp_file, ap_file, save_name, with_info)


def _get_positive_scale(size):
Expand Down
131 changes: 118 additions & 13 deletions visu_tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ def __init__(self, name='None', label='None', color='', style='', inputs=None) -


class IREdge(object):
def __init__(self, tail_name, head_name) -> None:
def __init__(self, tail_name, head_name, shape='', dtype='') -> None:
self.tail_name = tail_name
self.head_name = head_name
self.shape = shape
self.dtype = dtype


class VisuGraph(object):
"""Visu TVM Relay IR"""
def __init__(self, txt_file, save_name='example') -> None:
def __init__(self, txt_file, save_name='example', with_info=False) -> None:
self.graph_code = ''
self.nodes = dict()
self.edges = list()
Expand All @@ -50,6 +52,8 @@ def __init__(self, txt_file, save_name='example') -> None:
self.save_name = 'output/visu_{}_relay_ir'.format(save_name)
self.txt_file = txt_file
self.node_map = dict()
self.tensor_info = dict()
self.with_info = with_info

def random_color(self):
colors = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']
Expand All @@ -74,7 +78,7 @@ def parse_txt(self, txt_file=''):
if match_op:
self.parse_res.extend(match_op)
else:
self.parse_res.append(('', line))
self.parse_res.append(('}', line))

def init_node(self):
# graph.node(name='%input0', label='%input0', color='white', style='filled')
Expand All @@ -87,11 +91,15 @@ def init_edge(self):
# graph.edge(tail_name='%input0', head_name='%0', label='')
# base_eage = "graph.edge(tail_name='{}', head_name='{}')\n"
for edge in self.edges:
self.edge_code += "graph.edge(tail_name='{}', head_name='{}')\n".format(edge.tail_name, edge.head_name)
self.edge_code += "graph.edge(tail_name='{}', head_name='{}', label='{}')\n".format(edge.tail_name,
edge.head_name,
edge.shape)

def codegen(self):
from graphviz import Digraph

if '_bp' not in self.txt_file and self.with_info:
self.get_tensor_info(self.txt_file)
self.parse_txt(self.txt_file)
self.parse_node()
self.parse_edge()
Expand Down Expand Up @@ -129,7 +137,6 @@ def get_node_args(self, output_node, body_node):
args_list = [self.node_map.get(arg, arg) for arg in args_list]
else:
args_list = re.findall(pattern1, body_node)
# args_list = [self.node_map.get(arg[0], arg[0]) for arg in args_list]
args_list = [self.node_map.get(arg[0], arg[0]) if arg[0] else arg[1] for arg in args_list]

if args_list and isinstance(args_list[0], list):
Expand All @@ -152,21 +159,113 @@ def parse_node(self):
if not self.nodes.get(n, ''):
self.nodes[n] = IRNode(name=n, label=n, color='white')

# 在图的末尾添加一个空节点
self.nodes[''] = IRNode(name='', label='', color='white')

def parse_edge(self):
for k, v in self.nodes.items():
if len(v.inputs) > 0:
for n in v.inputs:
self.edges.append(IREdge(tail_name=n, head_name=k))
info = self.tensor_info.get(n, '')
if info:
info = info['shape']
self.edges.append(IREdge(tail_name=n, head_name=k, shape=info))
# elif not k and len(v.inputs) > 0:
# for n in v.inputs:
# self.edges.append(IREdge(tail_name=n, head_name=v.label))

# 添加一条指向空节点的边
info = self.tensor_info.get('}', '')
if info:
info = info['shape']
is_end = self.nodes.get('}', '')
if is_end and info:
self.edges.append(IREdge(tail_name='}', head_name='', shape=info))

def get_tensor_info(self, txt_file=''):
# only ap_file
with open(txt_file, 'r') as f:
lines = f.readlines()

pattern1 = re.compile(r'(%[a-z\d_.]*):.+?\[(.+?)]')
pattern2 = re.compile(r'->.+?\[(.+?)]')
pattern3 = re.compile(r'\).+?\[(.+?)]')
pattern4 = re.compile(r'(meta\[relay\.Constant]\[\d*]).+?\[(.+?)]')

self.tensor_info = dict()
flag_None = ''
for tmp_str in lines:
if tmp_str[:2] == 'fn': # match head
match_res = re.findall(pattern1, tmp_str)
for res in match_res:
info = res[1].split(', ')
self.tensor_info[res[0]] = {
'shape': ', '.join(info[:-1]),
'dtype': info[-1]
}
elif tmp_str[0] == '}': # match tail
match_res = re.findall(pattern2, tmp_str)
info = match_res[0].split(', ')
self.tensor_info['}'] = {
'shape': ', '.join(info[:-1]),
'dtype': info[-1]
}
else: # match body
if ' = ' in tmp_str and 'fn' not in tmp_str:
split_node = tmp_str.split(' = ')
if split_node[-1][0] == '(':
continue
match_res = re.findall(pattern3, split_node[-1])
if not match_res:
continue
info = match_res[-1].split(', ')
self.tensor_info[split_node[0].strip()] = {
'shape': ', '.join(info[:-1]),
'dtype': info[-1]
}
# 提取常量meta[relay.Constant][0]
match_res = re.findall(pattern4, split_node[-1])
if not match_res:
continue
for res in match_res:
info = res[-1].split(', ')
self.tensor_info[res[0].strip()] = {
'shape': ', '.join(info[:-1]),
'dtype': info[-1]
}
elif ' = ' in tmp_str and 'fn' in tmp_str:
split_node = tmp_str.split(' = ')
self.tensor_info[split_node[0].strip()] = {
'shape': None,
'dtype': None
}
flag_None = split_node[0].strip()
elif tmp_str.strip()[0] == '(':
continue
else:
match_res = re.findall(pattern3, tmp_str)
info = match_res[0].split(', ')
self.tensor_info[flag_None if flag_None else ''] = {
'shape': ', '.join(info[:-1]),
'dtype': info[-1]
}
# 提取常量meta[relay.Constant][0]
match_res = re.findall(pattern4, tmp_str)
if not match_res:
continue
for res in match_res:
info = res[-1].split(', ')
self.tensor_info[res[0].strip()] = {
'shape': ', '.join(info[:-1]),
'dtype': info[-1]
}


class VisuGraphFuseOps(VisuGraph):
"""Visu FuseOps /
MergeComposite Pass Relay IR"""
def __init__(self, txt_file, save_name='example') -> None:
super(VisuGraphFuseOps, self).__init__(txt_file, save_name)
def __init__(self, txt_file, save_name='example', with_info=False) -> None:
super(VisuGraphFuseOps, self).__init__(txt_file, save_name, with_info)
self.op_args_map = dict()
self.save_name = 'output/visu_{}_relay_ir_pass'.format(save_name)

Expand Down Expand Up @@ -246,6 +345,9 @@ def parse_node(self):
if not self.nodes.get(n, ''):
self.nodes[n] = IRNode(name=n, label=n, color='white')

# 在图的末尾添加一个空节点
self.nodes[''] = IRNode(name='', label='', color='white')

def split_fn_op(self):
pattern1 = re.compile(r'(%\d+).+{(.+)}')
pattern1_ = re.compile(r'(%[a-z]*\d+):|(%[a-zA-Z_]*\d+_\d+):')
Expand All @@ -270,7 +372,7 @@ def split_fn_op(self):
match_op = re.search(pattern3, fn_str).groups(0)
# '%3(%2, %b)'
args = match_op[-1].split(', ')
pnodes[''] = PNode(name='', type='op', inputs=args, body=match_op[0])
pnodes['}'] = PNode(name='', type='op', inputs=args, body=match_op[0])

return pnodes

Expand Down Expand Up @@ -307,8 +409,8 @@ class VisuGraphRUF(VisuGraph):
CombineParallelConv2D /
CombineParallelDense /
CombineParallelBatchMatmul Pass Relay IR"""
def __init__(self, txt_file, save_name='example') -> None:
super(VisuGraphRUF, self).__init__(txt_file, save_name)
def __init__(self, txt_file, save_name='example', with_info=False) -> None:
super(VisuGraphRUF, self).__init__(txt_file, save_name, with_info)
self.save_name = 'output/visu_{}_relay_ir_pass'.format(save_name)

def parse_node(self):
Expand All @@ -326,12 +428,15 @@ def parse_node(self):
if not self.nodes.get(n, ''):
self.nodes[n] = IRNode(name=n, label=n, color='white')

# 在图的末尾添加一个空节点
self.nodes[''] = IRNode(name='', label='', color='white')


class VisuGraphMC(VisuGraphFuseOps):
"""Visu MergeComposite Pass Relay IR"""
# 显示算子融合后的名称
def __init__(self, txt_file, save_name='example') -> None:
super(VisuGraphMC, self).__init__(txt_file, save_name)
def __init__(self, txt_file, save_name='example', with_info=False) -> None:
super(VisuGraphMC, self).__init__(txt_file, save_name, with_info)
self.save_name = 'output/visu_{}_relay_ir_pass'.format(save_name)

def parse_txt(self, txt_file=''):
Expand Down

0 comments on commit 1ae0563

Please sign in to comment.