diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..fe3e04b 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -68,7 +68,10 @@ def import_graph(hl_graph, model, args, input_names=None, verbose=False): # Run the Pytorch graph to get a trace and generate a graph from it trace, out = torch.jit._get_trace_graph(model, args) - torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) + try: + torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) + except TypeError as e: + torch_graph = trace # Dump list of nodes (DEBUG only) if verbose: