Skip to content

Commit

Permalink
Refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
HonzaCuhel committed Oct 16, 2024
1 parent 4f32288 commit 1ef8f34
Show file tree
Hide file tree
Showing 22 changed files with 312 additions and 217 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ repos:
hooks:
- id: black
language_version: python3.8
exclude: 'tools/yolov7/yolov7/'

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: no-commit-to-branch
args: ['--branch', 'main', '--branch', 'dev']
args: ['--branch', 'main']

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.10
Expand Down
19 changes: 9 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Tools-CLI

> [!NOTE]
> \[!NOTE\]\
> This is the latest version of tools CLI. If you are looking for the tools web application, please refer to the [web-app](https://github.com/luxonis/tools/tree/web-app) branch.

This application is used for exporting Yolo V5, V6, V7, V8 (OBB, instance segmentation, pose estimation, cls) and Gold YOLO object detection models to .ONNX.

## Running
Expand Down Expand Up @@ -50,14 +49,14 @@ tools shared_with_container/models/yolov6nr4.pt --imgsz "416"

### Arguments

* `model: str` = Path to the model.
* `imgsz: str` = Image input shape in the format `width height` or `width`. Default value `"416 416"`.
* `version: Optional[str]` =
* `use_rvc2: bool` = Whether to export for RVC2 or RVC3 devices. Default value `True`.
* `class_names: Optional[str]` = Optional list of classes separated by a comma, e.g. `"person, dog, cat"`
* `output_remote_url: Optional[str]` = Remote output url for the output .onnx model.
* `config_path: Optional[str]` = Optional path to an optional config.
* `put_file_plugin: Optional[str]` = Which plugin to use. Optional.
- `model: str` = Path to the model.
- `imgsz: str` = Image input shape in the format `width height` or `width`. Default value `"416 416"`.
- `version: Optional[str]` =
- `use_rvc2: bool` = Whether to export for RVC2 or RVC3 devices. Default value `True`.
- `class_names: Optional[str]` = Optional list of classes separated by a comma, e.g. `"person, dog, cat"`
- `output_remote_url: Optional[str]` = Remote output url for the output .onnx model.
- `config_path: Optional[str]` = Optional path to an optional config.
- `put_file_plugin: Optional[str]` = Which plugin to use. Optional.

## Credits

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ optional-dependencies = { dev = { file = ["requirements-dev.txt"] } }

[tool.ruff]
target-version = "py38"
exclude = ["tools/yolov7/yolov7/"]

[tool.ruff.lint]
ignore = ["F403", "B028", "B905", "D1"]
Expand Down
35 changes: 22 additions & 13 deletions tools/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Optional, List
from typing import Optional

import typer

Expand All @@ -18,12 +18,11 @@
YOLOV11_CONVERSION,
Config,
detect_version,
upload_file_to_remote,
resolve_path,
upload_file_to_remote,
)
from tools.utils.constants import MISC_DIR


logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

Expand Down Expand Up @@ -61,13 +60,13 @@ def convert(

if version is not None and version not in YOLO_VERSIONS:
logger.error("Wrong YOLO version selected!")
raise typer.Exit(code=1)
raise typer.Exit(code=1) from None

try:
imgsz = list(map(int, imgsz.split(" "))) if " " in imgsz else [int(imgsz)] * 2
except ValueError:
except ValueError as e:
logger.error('Invalid image size format. Must be "width height" or "width".')
raise typer.Exit(code=1)
raise typer.Exit(code=1) from e

if class_names:
class_names = [class_name.strip() for class_name in class_names.split(",")]
Expand Down Expand Up @@ -97,56 +96,66 @@ def convert(
logger.info("Loading model...")
if version == YOLOV5_CONVERSION:
from tools.yolo.yolov5_exporter import YoloV5Exporter

exporter = YoloV5Exporter(str(model_path), config.imgsz, config.use_rvc2)
elif version == YOLOV6R1_CONVERSION:
from tools.yolov6r1.yolov6_r1_exporter import YoloV6R1Exporter

exporter = YoloV6R1Exporter(str(model_path), config.imgsz, config.use_rvc2)
elif version == YOLOV6R3_CONVERSION:
from tools.yolov6r3.yolov6_r3_exporter import YoloV6R3Exporter

exporter = YoloV6R3Exporter(str(model_path), config.imgsz, config.use_rvc2)
elif version == GOLD_YOLO_CONVERSION:
from tools.yolov6r3.gold_yolo_exporter import GoldYoloExporter

exporter = GoldYoloExporter(str(model_path), config.imgsz, config.use_rvc2)
elif version == YOLOV6R4_CONVERSION:
from tools.yolo.yolov6_exporter import YoloV6R4Exporter

exporter = YoloV6R4Exporter(str(model_path), config.imgsz, config.use_rvc2)
elif version == YOLOV7_CONVERSION:
from tools.yolov7.yolov7_exporter import YoloV7Exporter

exporter = YoloV7Exporter(str(model_path), config.imgsz, config.use_rvc2)
elif version in [YOLOV8_CONVERSION, YOLOV9_CONVERSION, YOLOV11_CONVERSION]:
from tools.yolo.yolov8_exporter import YoloV8Exporter

exporter = YoloV8Exporter(str(model_path), config.imgsz, config.use_rvc2)
elif version == YOLOV10_CONVERSION:
from tools.yolo.yolov10_exporter import YoloV10Exporter

exporter = YoloV10Exporter(str(model_path), config.imgsz, config.use_rvc2)
else:
logger.error("Unrecognized model version.")
raise typer.Exit(code=1)
raise typer.Exit(code=1) from None
logger.info("Model loaded.")
except Exception as e:
logger.error(f"Error creating exporter: {e}")
raise typer.Exit(code=1)
raise typer.Exit(code=1) from e

# Export model
try:
logger.info("Exporting model...")
exporter.export_onnx()
logger.info("Model exported.")
except Exception as e:
logger.error(f"Error exporting model: {e}")
raise typer.Exit(code=1)
raise typer.Exit(code=1) from e
# Create NN archive
try:
logger.info("Creating NN archive...")
exporter.export_nn_archive(config.class_names)
logger.info("NN archive created.")
except Exception as e:
logger.error(f"Error creating NN archive: {e}")
raise typer.Exit(code=1)
raise typer.Exit(code=1) from e

# Upload to remote
if config.output_remote_url:
upload_file_to_remote(exporter.f_nn_archive, config.output_remote_url, config.put_file_plugin)
upload_file_to_remote(
exporter.f_nn_archive, config.output_remote_url, config.put_file_plugin
)
logger.info(f"Uploaded NN archive to {config.output_remote_url}")


Expand Down
17 changes: 8 additions & 9 deletions tools/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
from __future__ import annotations

from .backbones import YoloV6BackBone
from .exporter import Exporter
from .heads import (
DetectV6R1,
DetectV6R3,
DetectV6R4m,
DetectV6R4s,
DetectV8,
PoseV8,
OBBV8,
SegmentV8,
ClassifyV8,
DetectV5,
DetectV6R1,
DetectV6R3,
DetectV6R4m,
DetectV6R4s,
DetectV7,
DetectV8,
DetectV10,
PoseV8,
SegmentV8,
)
from .exporter import Exporter
from .stage2 import Multiplier


__all__ = [
"YoloV6BackBone",
"DetectV6R1",
Expand Down
43 changes: 27 additions & 16 deletions tools/modules/exporter.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from datetime import datetime
import os
from typing import List, Tuple, Optional
from datetime import datetime
from typing import List, Optional, Tuple

import onnx
import onnxsim
import torch
from luxonis_ml.nn_archive import ArchiveGenerator
from luxonis_ml.nn_archive.config_building_blocks import (
DataType,
Head,
InputType,
DataType,
)
from luxonis_ml.nn_archive.config_building_blocks.base_models.head_metadata import (
HeadYOLOMetadata,
Expand All @@ -22,13 +22,14 @@

class Exporter:
"""Exporter class to export models to ONNX and NN archive formats."""

def __init__(
self,
model_path: str,
imgsz: Tuple[int, int],
use_rvc2: bool,
self,
model_path: str,
imgsz: Tuple[int, int],
use_rvc2: bool,
subtype: str,
output_names: List[str] = ["output"],
output_names: List[str] = None,
all_output_names: Optional[List[str]] = None,
):
"""
Expand All @@ -53,8 +54,13 @@ def __init__(
self.number_of_channels = None
self.subtype = subtype
self.output_names = output_names
self.all_output_names = all_output_names if all_output_names is not None else output_names
self.output_folder = (OUTPUTS_DIR / f"{self.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}").resolve()
self.all_output_names = (
all_output_names if all_output_names is not None else output_names
)
self.output_folder = (
OUTPUTS_DIR
/ f"{self.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
).resolve()
# If output directory does not exist, create it
if not self.output_folder.exists():
self.output_folder.mkdir(parents=True)
Expand Down Expand Up @@ -107,7 +113,7 @@ def make_nn_archive(
n_prototypes: Optional[int] = None,
n_keypoints: Optional[int] = None,
is_softmax: Optional[bool] = None,
output_kwargs: Optional[dict] = {},
output_kwargs: Optional[dict] = None,
):
"""Export the model to NN archive format.
Expand All @@ -130,7 +136,10 @@ def make_nn_archive(
executables_paths = [str(self.f_onnx), stage2_executable_path]
else:
executables_paths = [str(self.f_onnx)]


if output_kwargs is None:
output_kwargs = {}

archive = ArchiveGenerator(
archive_name=self.model_name,
save_path=str(self.output_folder),
Expand Down Expand Up @@ -165,7 +174,7 @@ def make_nn_archive(
Head(
parser=parser,
metadata=HeadYOLOMetadata(
yolo_outputs=self.output_names,
yolo_outputs=self.output_names,
subtype=self.subtype,
n_classes=n_classes,
classes=class_list,
Expand All @@ -190,14 +199,16 @@ def make_nn_archive(
def export_nn_archive(self, class_names: Optional[List[str]] = None):
"""
Export the model to NN archive format.
Args:
class_list (Optional[List[str]], optional): List of class names. Defaults to None.
"""
nc = self.model.detect.nc
# If class names are provided, use them
if class_names is not None:
assert len(class_names) == nc, f"Number of the given class names {len(class_names)} does not match number of classes {nc} provided in the model!"
assert (
len(class_names) == nc
), f"Number of the given class names {len(class_names)} does not match number of classes {nc} provided in the model!"
names = class_names
else:
# Check if the model has a names attribute
Expand All @@ -206,4 +217,4 @@ def export_nn_archive(self, class_names: Optional[List[str]] = None):
else:
names = [f"Class_{i}" for i in range(nc)]

self.make_nn_archive(names, nc)
self.make_nn_archive(names, nc)
Loading

0 comments on commit 1ef8f34

Please sign in to comment.