Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Nov 7, 2024
1 parent 4ba070b commit 803aa39
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 51 deletions.
11 changes: 11 additions & 0 deletions modelconverter/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ class ModelType(str, Enum):
RVC4 = "RVC4"
HAILO = "HAILO"

@classmethod
def from_suffix(cls, suffix: str) -> "ModelType":
if suffix == ".onnx":
return cls.ONNX
elif suffix == ".tflite":
return cls.TFLITE
elif suffix in [".xml", ".bin"]:
return cls.IR
else:
raise ValueError(f"Unsupported model format: {suffix}")


class Format(str, Enum):
NATIVE = "native"
Expand Down
3 changes: 3 additions & 0 deletions modelconverter/hub/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .__main__ import convert

__all__ = ["convert"]
109 changes: 58 additions & 51 deletions modelconverter/hub/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
request_info,
)
from modelconverter.cli.types import License, ModelPrecision
from modelconverter.utils.config import SingleStageConfig
from modelconverter.utils.types import Target

from .hub_requests import Request
Expand Down Expand Up @@ -232,7 +233,7 @@ def version_ls(
limit=limit,
sort=sort,
order=order,
keys=["id", "model_id", "version", "slug", "platforms"],
keys=["id", "version", "slug", "platforms"],
)


Expand Down Expand Up @@ -350,8 +351,6 @@ def instance_ls(
order=order,
keys=[
"id",
"model_version_id",
"model_id",
"slug",
"platforms",
"is_nn_archive",
Expand Down Expand Up @@ -439,7 +438,7 @@ def instance_create(
"model_type": model_type,
"model_precision_type": model_precision_type,
"tags": tags or [],
"input_shape": input_shape,
"input_shape": [input_shape] if input_shape else None,
"quantization_data": quantization_data.name,
"is_deployable": is_deployable,
}
Expand Down Expand Up @@ -486,28 +485,34 @@ def upload(file_path: str, identifier: IdentifierArgument):

@instance.command()
def export(
name: NameArgument,
identifier: IdentifierArgument,
target: TargetArgument,
target_precision: ModelPrecisionOption = ModelPrecision.INT8,
quantization_data: QuantizationOption = Quantization.RANDOM,
name: Optional[str] = None,
) -> Dict[str, Any]:
"""Exports a model instance."""
model_instance_id = get_resource_id(identifier, "modelInstances")
json = {"name": name, "quantization_data": quantization_data.name}
json: Dict[str, Any] = {
"name": name,
"quantization_data": quantization_data.name,
}
if target in [Target.RVC4]:
json["target_precision"] = target_precision
res = Request.post(
f"modelInstances/{model_instance_id}/export/{target.value}",
json=json,
).json()
print(
f"Model instance '{name}' created for {target.name} export with ID '{res['id']}'"
)
return res.json()
return res


@app.command()
def convert(
target: TargetArgument,
path: PathOption,
path: PathOption = None,
name: NameOption = None,
license_type: LicenseTypeOptionRequired = License.UNDEFINED,
config_path: PathOption = None,
Expand All @@ -526,6 +531,7 @@ def convert(
domain: DomainOption = None,
tags: TagsOption = None,
version_id: ModelVersionIDOption = None,
output_dir: OutputDirOption = None,
opts: OptsArgument = None,
) -> Path:
"""Starts the online conversion process."""
Expand All @@ -540,37 +546,13 @@ def convert(
raise ValueError(
"Only single-stage models are supported with online conversion."
)

name = name or cfg.name

cfg = next(iter(cfg.stages.values()))

suffix = cfg.input_model.suffix
if suffix == ".onnx":
model_type = ModelType.ONNX
elif suffix == ".tflite":
model_type = ModelType.TFLITE
elif suffix in [".xml", ".bin"]:
model_type = ModelType.IR
else:
raise ValueError(f"Unsupported model format: {suffix}")

shape = cfg.inputs[0].shape
layout = cfg.inputs[0].layout

if shape is not None:
if layout is not None and "H" in layout and "W" in layout:
h, w = shape[layout.index("H")], shape[layout.index("W")]
version_name = f"{name} {h}x{w}"
elif len(shape) == 4:
if model_type == ModelType.TFLITE:
h, w = shape[1], shape[2]
else:
h, w = shape[2], shape[3]
version_name = f"{name} {h}x{w}"
else:
version_name = name
else:
version_name = name
model_type = ModelType.from_suffix(cfg.input_model.suffix)
version_name = _get_version_name(cfg, model_type, name)

if model_id is None and version_id is None:
try:
Expand All @@ -595,19 +577,7 @@ def convert(
print("`--model-id` is required to create a new model")
exit(1)

versions = Request.get(
"modelVersions/", params={"model_id": model_id}
).json()
if not versions:
version = version or "0.1.0"
else:
max_version = Version(versions[0]["version"])
for v in versions[1:]:
max_version = max(max_version, Version(v["version"]))
max_version = str(max_version)
version_numbers = max_version.split(".")
version_numbers[-1] = str(int(version_numbers[-1]) + 1)
version = ".".join(version_numbers)
version = version or _get_version_number(model_id)

version_id = version_create(
version_name,
Expand All @@ -622,9 +592,10 @@ def convert(
)["id"]

assert version_id is not None
shape = cfg.inputs[0].shape
instance_name = f"{version_name} base instance"
instance_id = instance_create(
instance_name, version_id, model_type, silent=True
instance_name, version_id, model_type, input_shape=shape, silent=True
)["id"]

upload(str(cfg.input_model), instance_id)
Expand All @@ -636,11 +607,11 @@ def convert(
exported_instance_name = f"{version_name} exported to {target.value}"

instance_id = export(
exported_instance_name,
instance_id,
target,
target_precision=target_precision,
quantization_data=quantization_data,
name=exported_instance_name,
)["id"]

with Progress() as progress:
Expand All @@ -653,4 +624,40 @@ def convert(
sleep(5)
pass

return instance_download(instance_id, None)
return instance_download(instance_id, output_dir)


def _get_version_name(
cfg: SingleStageConfig, model_type: ModelType, name: str
) -> str:
shape = cfg.inputs[0].shape
layout = cfg.inputs[0].layout

if shape is not None:
if layout is not None and "H" in layout and "W" in layout:
h, w = shape[layout.index("H")], shape[layout.index("W")]
return f"{name} {h}x{w}"
elif len(shape) == 4:
if model_type == ModelType.TFLITE:
h, w = shape[1], shape[2]
else:
h, w = shape[2], shape[3]
return f"{name} {h}x{w}"
return name


def _get_version_number(model_id: str) -> str:
versions = Request.get(
"modelVersions/", params={"model_id": model_id}
).json()
if not versions:
version = "0.1.0"
else:
max_version = Version(versions[0]["version"])
for v in versions[1:]:
max_version = max(max_version, Version(v["version"]))
max_version = str(max_version)
version_numbers = max_version.split(".")
version_numbers[-1] = str(int(version_numbers[-1]) + 1)
version = ".".join(version_numbers)
return version

0 comments on commit 803aa39

Please sign in to comment.