Skip to content

Commit

Permalink
Add ONNXModifier for optimising the ONNX model before converting for …
Browse files Browse the repository at this point in the history
…RVC4 execution (#55)

Co-authored-by: Martin Kozlovský <[email protected]>
  • Loading branch information
ptoupas and kozlov721 authored Dec 13, 2024
1 parent 9d09fd0 commit deac688
Show file tree
Hide file tree
Showing 11 changed files with 1,278 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unittests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_S3_ENDPOINT_URL: ${{ secrets.AWS_S3_ENDPOINT_URL }}
GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GCP_CREDENTIALS }}
HUB_AI_API_KEY: ${{ secrets.HUB_AI_API_KEY }}
run: python -m pytest tests/test_utils

1 change: 1 addition & 0 deletions modelconverter/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def get_target_specific_options(
json_cfg = cfg.model_dump(mode="json")
options = {
"disable_onnx_simplification": cfg.disable_onnx_simplification,
"disable_onnx_optimisation": cfg.disable_onnx_optimisation,
"inputs": json_cfg["inputs"],
}
if target == "rvc4":
Expand Down
1 change: 1 addition & 0 deletions modelconverter/packages/base_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self.outputs = {out.name: out for out in config.outputs}
self.keep_intermediate_outputs = config.keep_intermediate_outputs
self.disable_onnx_simplification = config.disable_onnx_simplification
self.disable_onnx_optimisation = config.disable_onnx_optimisation

self.model_name = self.input_model.stem

Expand Down
20 changes: 20 additions & 0 deletions modelconverter/packages/rvc4/exporter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import shutil
import subprocess
import time
Expand All @@ -6,6 +7,7 @@
from typing import Any, Dict, List, NamedTuple, Optional, cast

from modelconverter.utils import (
ONNXModifier,
exit_with,
onnx_attach_normalization_to_inputs,
read_image,
Expand Down Expand Up @@ -57,6 +59,24 @@ def __init__(self, config: SingleStageConfig, output_dir: Path):
self._attach_suffix(self.input_model, "modified.onnx"),
self.inputs,
)

if not config.disable_onnx_optimisation:
onnx_modifier = ONNXModifier(
model_path=self.input_model,
output_path=self._attach_suffix(
self.input_model, "modified_optimised.onnx"
),
)

if (
onnx_modifier.modify_onnx()
and onnx_modifier.compare_outputs()
):
logger.info("ONNX model has been optimised for RVC4.")
shutil.move(onnx_modifier.output_path, self.input_model)
else:
if os.path.exists(onnx_modifier.output_path):
os.remove(onnx_modifier.output_path)
else:
logger.warning(
"Input file type is not ONNX. Skipping pre-processing."
Expand Down
3 changes: 2 additions & 1 deletion modelconverter/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
modelconverter_config_to_nn,
process_nn_archive,
)
from .onnx_tools import onnx_attach_normalization_to_inputs
from .onnx_tools import ONNXModifier, onnx_attach_normalization_to_inputs
from .subprocess import subprocess_run

__all__ = [
Expand All @@ -37,6 +37,7 @@
"S3Exception",
"SubprocessException",
"exit_with",
"ONNXModifier",
"onnx_attach_normalization_to_inputs",
"read_calib_dir",
"read_image",
Expand Down
1 change: 1 addition & 0 deletions modelconverter/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class SingleStageConfig(CustomBaseModel):

keep_intermediate_outputs: bool = True
disable_onnx_simplification: bool = False
disable_onnx_optimisation: bool = False
output_remote_url: Optional[str] = None
put_file_plugin: Optional[str] = None

Expand Down
Loading

0 comments on commit deac688

Please sign in to comment.