diff --git a/docs/developer-guide/api/README.md b/docs/developer-guide/api/README.md index c34741003..a351c806b 100644 --- a/docs/developer-guide/api/README.md +++ b/docs/developer-guide/api/README.md @@ -154,7 +154,8 @@ - [`onnx_impl_utils.compute_onnx_pool_padding`](./concrete.ml.onnx.onnx_impl_utils.md#function-compute_onnx_pool_padding): Compute any additional padding needed to compute pooling layers. - [`onnx_impl_utils.numpy_onnx_pad`](./concrete.ml.onnx.onnx_impl_utils.md#function-numpy_onnx_pad): Pad a tensor according to ONNX spec, using an optional custom pad value. - [`onnx_impl_utils.onnx_avgpool_compute_norm_const`](./concrete.ml.onnx.onnx_impl_utils.md#function-onnx_avgpool_compute_norm_const): Compute the average pooling normalization constant. -- [`onnx_model_manipulations.clean_graph_after_node`](./concrete.ml.onnx.onnx_model_manipulations.md#function-clean_graph_after_node): Clean the graph of the onnx model by removing nodes after the given node name. +- [`onnx_model_manipulations.clean_graph_after_node_name`](./concrete.ml.onnx.onnx_model_manipulations.md#function-clean_graph_after_node_name): Clean the graph of the onnx model by removing nodes after the given node name. +- [`onnx_model_manipulations.clean_graph_after_node_op_type`](./concrete.ml.onnx.onnx_model_manipulations.md#function-clean_graph_after_node_op_type): Clean the graph of the onnx model by removing nodes after the given node type. - [`onnx_model_manipulations.keep_following_outputs_discard_others`](./concrete.ml.onnx.onnx_model_manipulations.md#function-keep_following_outputs_discard_others): Keep the outputs given in outputs_to_keep and remove the others from the model. - [`onnx_model_manipulations.remove_identity_nodes`](./concrete.ml.onnx.onnx_model_manipulations.md#function-remove_identity_nodes): Remove identity nodes from a model. - [`onnx_model_manipulations.remove_node_types`](./concrete.ml.onnx.onnx_model_manipulations.md#function-remove_node_types): Remove unnecessary nodes from the ONNX graph. diff --git a/docs/developer-guide/api/concrete.ml.onnx.onnx_model_manipulations.md b/docs/developer-guide/api/concrete.ml.onnx.onnx_model_manipulations.md index a031fe99b..a86250e1a 100644 --- a/docs/developer-guide/api/concrete.ml.onnx.onnx_model_manipulations.md +++ b/docs/developer-guide/api/concrete.ml.onnx.onnx_model_manipulations.md @@ -99,10 +99,14 @@ ______________________________________________________________________ -## function `clean_graph_after_node` +## function `clean_graph_after_node_name` ```python -clean_graph_after_node(onnx_model: ModelProto, node_name: str) +clean_graph_after_node_name( + onnx_model: ModelProto, + node_name: str, + fail_if_not_found: bool = True +) ``` Clean the graph of the onnx model by removing nodes after the given node name. @@ -111,3 +115,34 @@ Clean the graph of the onnx model by removing nodes after the given node name. - `onnx_model` (onnx.ModelProto): The onnx model. - `node_name` (str): The node's name whose following nodes will be removed. +- `fail_if_not_found` (bool): If true, abort if the node name is not found + +**Raises:** + +- `ValueError`: if the node name is not found and if fail_if_not_found is set + +______________________________________________________________________ + + + +## function `clean_graph_after_node_op_type` + +```python +clean_graph_after_node_op_type( + onnx_model: ModelProto, + node_op_type: str, + fail_if_not_found: bool = True +) +``` + +Clean the graph of the onnx model by removing nodes after the given node type. + +**Args:** + +- `onnx_model` (onnx.ModelProto): The onnx model. +- `node_op_type` (str): The node's op_type whose following nodes will be removed. +- `fail_if_not_found` (bool): If true, abort if the node op_type is not found + +**Raises:** + +- `ValueError`: if the node op_type is not found and if fail_if_not_found is set