Skip to content

Commit

Permalink
Feat/upload nn archive (#79)
Browse files Browse the repository at this point in the history
* Update the uplod code

* Add upload logging
  • Loading branch information
HonzaCuhel authored Jun 27, 2024
1 parent b92f9ec commit 6393717
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "tools"
version = "0.0.1"
version = "0.0.2"
description = "Converter for YOLO models into .ONNX format."
readme = "README.md"
requires-python = ">=3.8"
Expand Down
3 changes: 2 additions & 1 deletion tools/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def convert(

# Upload to remote
if config.output_remote_url:
upload_file_to_remote(exporter.f_onnx, config.output_remote_url, config.put_file_plugin)
upload_file_to_remote(exporter.f_nn_archive, config.output_remote_url, config.put_file_plugin)
logger.info(f"Uploaded NN archive to {config.output_remote_url}")


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions tools/modules/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
self.model_name = os.path.basename(self.model_path).split(".")[0]
# Set up file paths
self.f_onnx = None
self.f_nn_archive = None
self.use_rvc2 = use_rvc2
self.number_of_channels = None
self.subtype = subtype
Expand Down Expand Up @@ -103,6 +104,7 @@ def make_nn_archive(
conf_threshold (float): Confidence threshold
max_det (int): Maximum number of detections
"""
self.f_nn_archive = (self.output_folder / f"{self.model_name}.tar.xz").resolve()
archive = ArchiveGenerator(
archive_name=self.model_name,
save_path=str(self.output_folder),
Expand Down
2 changes: 2 additions & 0 deletions tools/yolo/yolov8_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def export_nn_archive(self, class_names: Optional[List[str]] = None):
assert len(class_names) == len(names), f"Number of the given class names {len(class_names)} does not match number of classes {len(names)} provided in the model!"
names = class_names

self.f_nn_archive = (self.output_folder / f"{self.model_name}.tar.xz").resolve()

if self.mode == DETECT_MODE:
self.make_nn_archive(names, self.model.model[-1].nc)
elif self.mode == SEGMENT_MODE:
Expand Down

0 comments on commit 6393717

Please sign in to comment.