From 8e14b20dbe9888630b07d70e2a55cf080b5bd5a9 Mon Sep 17 00:00:00 2001 From: Julien Blanchon Date: Thu, 23 Jun 2022 22:23:23 +0200 Subject: [PATCH] Fix TypeError for Pytorch Model #94 --- hiddenlayer/pytorch_builder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..fe3e04b 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -68,7 +68,10 @@ def import_graph(hl_graph, model, args, input_names=None, verbose=False): # Run the Pytorch graph to get a trace and generate a graph from it trace, out = torch.jit._get_trace_graph(model, args) - torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) + try: + torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) + except TypeError as e: + torch_graph = trace # Dump list of nodes (DEBUG only) if verbose: