Refactor Archiver and add BaseHead #139

Merged · 24 commits · Nov 28, 2024

Commits
92eb83e
feat: add Archivable interface
sokovninn Nov 25, 2024
a43a091
fix: archivable export
sokovninn Nov 25, 2024
cea5650
feat: add BaseHead class
sokovninn Nov 26, 2024
73f6451
fix: remove archivable interface
sokovninn Nov 26, 2024
ac65e6f
feat: update EfficientBBoxHead to inherit from BaseHead
sokovninn Nov 26, 2024
c0fe60b
style: formatting
sokovninn Nov 26, 2024
3dfed24
fix: remove Archivable interface from head classes
sokovninn Nov 26, 2024
34b0c29
fix: improve BaseHead structure
sokovninn Nov 27, 2024
5ec6ca5
feat: update all heads to inherit from BaseHead
sokovninn Nov 27, 2024
616c4c3
feat: add a warning about not implemented get_custom_head_config in t…
sokovninn Nov 27, 2024
35e9adc
fix: update test archive config
sokovninn Nov 27, 2024
034891f
test: sort fields in configs before comparison
sokovninn Nov 27, 2024
105f82f
fix: remove deprecated enums
sokovninn Nov 27, 2024
7b310ec
fix: add subtype to the keypoint head
sokovninn Nov 27, 2024
ba1512b
fix: update parser for the EfficientKeypointBBoxHead
sokovninn Nov 27, 2024
1ae48a8
Merge branch 'feat/archiver-refactor' of https://github.com/luxonis/l…
sokovninn Nov 27, 2024
27dff57
fix: update test config
sokovninn Nov 27, 2024
a58163d
fix: correct configs merge
sokovninn Nov 27, 2024
7d0ccc0
fix: move EMACallback in the light detection model config
sokovninn Nov 27, 2024
238fd35
fix: revert dataset_name in the detection light config
sokovninn Nov 28, 2024
7f590f6
added coverage rule
kozlov721 Nov 28, 2024
050ccc4
fix: remove redundant init in the BaseHead
sokovninn Nov 28, 2024
dbc9027
fix: remove deep_merge_dicts
sokovninn Nov 28, 2024
d2a05f4
fix: remove inheritance from Generic in BaseHead
sokovninn Nov 28, 2024
12 changes: 5 additions & 7 deletions configs/detection_light_model.yaml
@@ -26,6 +26,11 @@ trainer:
n_log_images: 8

callbacks:
- name: EMACallback
params:
decay: 0.9999
use_dynamic_decay: True
decay_tau: 2000
- name: ExportOnTrainEnd
- name: TestOnTrainEnd

@@ -43,10 +48,3 @@ trainer:
params:
T_max: *epochs
eta_min: 0.000384

callbacks:
- name: EMACallback
params:
decay: 0.9999
use_dynamic_decay: True
decay_tau: 2000
12 changes: 7 additions & 5 deletions luxonis_train/core/core.py
@@ -704,7 +704,11 @@ def archive(
return self._archive(path)

def _archive(self, path: str | Path | None = None) -> Path:
from .utils.archive_utils import get_heads, get_inputs, get_outputs
from .utils.archive_utils import (
get_head_configs,
get_inputs,
get_outputs,
)

archive_name = self.cfg.archiver.name or self.cfg.model.name
archive_save_directory = Path(self.run_save_dir, "archive")
@@ -761,11 +765,9 @@ def _mult(lst: list[float | int]) -> list[float]:
}
)

heads = get_heads(
self.cfg,
heads = get_head_configs(
self.lightning_module,
outputs,
self.loaders["train"].get_classes(),
self.lightning_module.nodes, # type: ignore
)

model = {
159 changes: 35 additions & 124 deletions luxonis_train/core/utils/archive_utils.py
@@ -9,12 +9,8 @@
)
from onnx.onnx_pb import TensorProto

from luxonis_train.config import Config
from luxonis_train.nodes.base_node import BaseNode
from luxonis_train.nodes.enums.head_categorization import (
ImplementedHeads,
ImplementedHeadsIsSoxtmaxed,
)
from luxonis_train.models import LuxonisLightningModule
from luxonis_train.nodes.heads import BaseHead

logger = logging.getLogger(__name__)

@@ -104,71 +100,9 @@ def _get_onnx_inputs(onnx_path: Path) -> dict[str, MetadataDict]:
return inputs


def _get_classes(
node_name: str, node_task: str | None, classes: dict[str, list[str]]
) -> list[str]:
if not node_task:
match node_name:
case "ClassificationHead":
node_task = "classification"
case "EfficientBBoxHead":
node_task = "boundingbox"
case "SegmentationHead" | "BiSeNetHead" | "DDRNetSegmentationHead":
node_task = "segmentation"
case "EfficientKeypointBBoxHead":
node_task = "keypoints"
case _: # pragma: no cover
raise ValueError("Node does not map to a default task.")

return classes.get(node_task, [])


def _get_head_specific_parameters(
nodes: dict[str, BaseNode], head_name: str, head_alias: str
) -> dict:
"""Get parameters specific to head.

@type nodes: dict[str, BaseNode]
@param nodes: Dictionary of nodes.
@type head_name: str
@param head_name: Name of the head (e.g. 'EfficientBBoxHead').
@type head_alias: str
@param head_alias: Alias of the head (e.g. 'detection_head').
"""

parameters = {}
if head_name == "ClassificationHead":
parameters["is_softmax"] = getattr(
ImplementedHeadsIsSoxtmaxed, head_name
).value
elif head_name == "EfficientBBoxHead":
parameters["subtype"] = "yolov6"
head_node = nodes[head_alias]
parameters["iou_threshold"] = head_node.iou_thres
parameters["conf_threshold"] = head_node.conf_thres
parameters["max_det"] = head_node.max_det
elif head_name in [
"SegmentationHead",
"BiSeNetHead",
"DDRNetSegmentationHead",
]:
parameters["is_softmax"] = getattr(
ImplementedHeadsIsSoxtmaxed, head_name
).value
elif head_name == "EfficientKeypointBBoxHead":
# or appropriate subtype
head_node = nodes[head_alias]
parameters["iou_threshold"] = head_node.iou_thres
parameters["conf_threshold"] = head_node.conf_thres
parameters["max_det"] = head_node.max_det
parameters["n_keypoints"] = head_node.n_keypoints
else: # pragma: no cover
raise ValueError("Unknown head name")
return parameters


def _get_head_outputs(
outputs: list[dict], head_name: str, head_type: str
outputs: list[dict],
head_name: str,
) -> list[str]:
"""Get model outputs in a head-specific format.

@@ -177,8 +111,6 @@ def _get_head_outputs(
@type head_name: str
@param head_name: Type of the head (e.g. 'EfficientBBoxHead') or its
custom alias.
@type head_type: str
@param head_name: Type of the head (e.g. 'EfficientBBoxHead').
@rtype: list[str]
@return: List of output names.
"""
@@ -192,63 +124,42 @@ def _get_head_outputs(
return output_names


def get_heads(
cfg: Config,
def get_head_configs(
lightning_module: LuxonisLightningModule,
outputs: list[dict],
class_dict: dict[str, list[str]],
nodes: dict[str, BaseNode],
) -> list[dict]:
"""Get model heads.

@type cfg: Config
@param cfg: Configuration object.
@type lightning_module: LuxonisLightningModule
@param lightning_module: Lightning module.
@type outputs: list[dict]
@param outputs: List of model outputs.
@type class_dict: dict[str, list[str]]
@param class_dict: Dictionary of classes.
@type nodes: dict[str, BaseNode]
@param nodes: Dictionary of nodes.
@param outputs: List of NN Archive outputs.
@rtype: list[dict]
@return: List of head configurations.
"""
heads = []
head_configs = []
head_names = set()
for node in cfg.model.nodes:
node_name = node.name
node_alias = node.alias or node_name
if "aux-segmentation" in node_alias:

for node_name, node in lightning_module.nodes.items():
if not isinstance(node, BaseHead) or node.remove_on_export:
continue
if node_alias in cfg.model.outputs:
if node_name in ImplementedHeads.__members__:
parser = getattr(ImplementedHeads, node_name).value
task = node.task
if isinstance(task, dict):
task = str(next(iter(task.values())))

classes = _get_classes(node_name, task, class_dict)

export_output_names = nodes[node_alias].export_output_names
if export_output_names is not None:
head_outputs = export_output_names
else:
head_outputs = _get_head_outputs(
outputs, node_alias, node_name
)

if node_alias in head_names:
curr_head_name = f"{node_alias}_{len(head_names)}" # add suffix if name is already present
else:
curr_head_name = node_alias
head_names.add(curr_head_name)
head_dict = {
"name": curr_head_name,
"parser": parser,
"metadata": {
"classes": classes,
"n_classes": len(classes),
},
"outputs": head_outputs,
}
head_dict["metadata"].update(
_get_head_specific_parameters(nodes, node_name, node_alias)
)
heads.append(head_dict)
return heads
try:
head_config = node.get_head_config()
except NotImplementedError as e:
logger.error(f"Failed to archive head `{node_name}`: {e}")
continue
head_name = (
node_name
if node_name not in head_names
else f"{node_name}_{len(head_names)}"
)
head_names.add(head_name)

head_outputs = node.export_output_names or _get_head_outputs(
outputs, node_name
)
head_config.update({"name": head_name, "outputs": head_outputs})

head_configs.append(head_config)

return head_configs
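
For orientation, a minimal sketch of the shape of one entry returned by get_head_configs. The parser and is_softmax values match the BiSeNetHead changes later in this PR; the class names and output names are made-up placeholders, not taken from a real run:

# Illustrative shape of one head config assembled by get_head_configs().
head_config = {
    "name": "BiSeNetHead",  # node name; suffixed if already taken
    "parser": "SegmentationParser",  # from BaseHead.parser
    "metadata": {
        "classes": ["background", "road"],  # hypothetical class names
        "n_classes": 2,
        "is_softmax": False,  # merged in from get_custom_head_config()
    },
    "outputs": ["segmentation_output"],  # hypothetical ONNX output name
}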
8 changes: 8 additions & 0 deletions luxonis_train/models/luxonis_lightning.py
@@ -32,6 +32,7 @@
)
from luxonis_train.config import AttachedModuleConfig, Config
from luxonis_train.nodes import BaseNode
from luxonis_train.nodes.heads import BaseHead
from luxonis_train.utils import (
DatasetMetadata,
Kwargs,
@@ -353,6 +354,13 @@ def _initiate_nodes(
dataset_metadata=self.dataset_metadata,
**node_kwargs,
)
if isinstance(node, BaseHead):
try:
node.get_custom_head_config()
except NotImplementedError:
logger.warning(
f"Head {node_name} does not implement get_custom_head_config method. Archivation of this head will fail."
)
node_outputs = node.run(node_dummy_inputs)

dummy_inputs[node_name] = node_outputs
23 changes: 0 additions & 23 deletions luxonis_train/nodes/enums/head_categorization.py

This file was deleted.

2 changes: 2 additions & 0 deletions luxonis_train/nodes/heads/__init__.py
@@ -1,3 +1,4 @@
from .base_head import BaseHead
from .bisenet_head import BiSeNetHead
from .classification_head import ClassificationHead
from .ddrnet_segmentation_head import DDRNetSegmentationHead
@@ -7,6 +8,7 @@
from .segmentation_head import SegmentationHead

__all__ = [
"BaseHead",
"BiSeNetHead",
"ClassificationHead",
"EfficientBBoxHead",
52 changes: 52 additions & 0 deletions luxonis_train/nodes/heads/base_head.py
@@ -0,0 +1,52 @@
from luxonis_train.nodes.base_node import (
BaseNode,
ForwardInputT,
ForwardOutputT,
)


class BaseHead(
BaseNode[ForwardInputT, ForwardOutputT],
):
"""Base class for all heads in the model.

@type parser: str | None
@ivar parser: Parser to use for the head.
"""

parser: str | None = None

def get_head_config(self) -> dict:
"""Get head configuration.

@rtype: dict
@return: Head configuration.
"""
config = self._get_base_head_config()
custom_config = self.get_custom_head_config()
config["metadata"].update(custom_config)
return config

def _get_base_head_config(self) -> dict:
"""Get base head configuration.

@rtype: dict
@return: Base head configuration.
"""
return {
"parser": self.parser,
"metadata": {
"classes": self.class_names,
"n_classes": self.n_classes,
},
}

def get_custom_head_config(self) -> dict:
"""Get custom head configuration.

@rtype: dict
@return: Custom head configuration.
"""
raise NotImplementedError(
"get_custom_head_config method must be implemented."
)
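
As a usage sketch, a concrete head only needs to set parser and implement get_custom_head_config; everything else comes from BaseHead. The head below is hypothetical (a real subclass also implements a meaningful forward and the rest of the node API):

from torch import Tensor

from luxonis_train.nodes.heads import BaseHead


class ToySegmentationHead(BaseHead[Tensor, Tensor]):
    # Parser name mirrors the one used by BiSeNetHead in this PR.
    parser: str = "SegmentationParser"

    def forward(self, inputs: Tensor) -> Tensor:
        # Identity forward, just to make the sketch self-contained.
        return inputs

    def get_custom_head_config(self) -> dict:
        # Head-specific metadata, merged into the base config's "metadata".
        return {"is_softmax": False}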
15 changes: 13 additions & 2 deletions luxonis_train/nodes/heads/bisenet_head.py
@@ -3,17 +3,18 @@
from torch import Tensor, nn

from luxonis_train.enums import TaskType
from luxonis_train.nodes.base_node import BaseNode
from luxonis_train.nodes.blocks import ConvModule
from luxonis_train.nodes.heads import BaseHead
from luxonis_train.utils import infer_upscale_factor


class BiSeNetHead(BaseNode[Tensor, Tensor]):
class BiSeNetHead(BaseHead[Tensor, Tensor]):
in_height: int
in_width: int
in_channels: int

tasks: list[TaskType] = [TaskType.SEGMENTATION]
parser: str = "SegmentationParser"

def __init__(self, intermediate_channels: int = 64, **kwargs: Any):
"""BiSeNet segmentation head.
@@ -56,3 +57,13 @@ def forward(self, inputs: Tensor) -> Tensor:
x = self.conv_3x3(inputs)
x = self.conv_1x1(x)
return self.upscale(x)

def get_custom_head_config(self) -> dict:
"""Returns custom head configuration.

@rtype: dict
@return: Custom head configuration.
"""
return {
"is_softmax": False,
}