From eb67e8aa1bc748b86783d0a9d761ea1f1126deae Mon Sep 17 00:00:00 2001 From: HonzaCuhel Date: Mon, 2 Sep 2024 13:52:16 +0200 Subject: [PATCH] Add parser parametrization to the exporter --- tools/modules/exporter.py | 4 +++- tools/yolo/yolov8_exporter.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tools/modules/exporter.py b/tools/modules/exporter.py index caadfcd..bfae4cc 100644 --- a/tools/modules/exporter.py +++ b/tools/modules/exporter.py @@ -101,6 +101,7 @@ def make_nn_archive( iou_threshold: float = 0.5, conf_threshold: float = 0.5, max_det: int = 300, + parser: str = "YOLO", stage2_executable_path: Optional[str] = None, postprocessor_path: Optional[str] = None, n_prototypes: Optional[int] = None, @@ -116,6 +117,7 @@ def make_nn_archive( iou_threshold (float): Intersection over Union threshold conf_threshold (float): Confidence threshold max_det (int): Maximum number of detections + parser (str): Parser type, defaults to "YOLO" 2stage_executable_path (Optional[str], optional): Path to the executables. Defaults to None. postprocessor_path (Optional[str], optional): Path to the postprocessor. Defaults to None. n_prototypes (Optional[int], optional): Number of prototypes. Defaults to None. @@ -161,7 +163,7 @@ def make_nn_archive( ], "heads": [ Head( - parser="YOLO", + parser=parser, metadata=HeadYOLOMetadata( yolo_outputs=self.output_names, subtype=self.subtype, diff --git a/tools/yolo/yolov8_exporter.py b/tools/yolo/yolov8_exporter.py index bfc3b6a..e57facb 100644 --- a/tools/yolo/yolov8_exporter.py +++ b/tools/yolo/yolov8_exporter.py @@ -152,6 +152,7 @@ def export_nn_archive(self, class_names: Optional[List[str]] = None): self.make_nn_archive( names, self.model.model[-1].nc, + parser="YOLOExtendedParser", stage2_executable_path=str(self.f_stage2_onnx), postprocessor_path=self.stage2_filename, n_prototypes=32, @@ -177,6 +178,7 @@ def export_nn_archive(self, class_names: Optional[List[str]] = None): self.make_nn_archive( names, self.model.model[-1].nc, + parser="YOLOExtendedParser", n_keypoints=17, output_kwargs={ "keypoints_outputs": ["kpt_output"]