Skip to content

Commit

Permalink
Custom Loaders Support (#27)
Browse files Browse the repository at this point in the history
* support for custom loaders and datasets

* updated configs

* custom loaders in inspect command

* updated inspect for multi-task labels

* removed custom loader from test config

* deleted comment

* deleted comment

* removed custom dataset

* removed comment

* skipping archiver test untill fixed in luxonis-ml

* [Automated] Updated coverage badge

---------

Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
kozlov721 and actions-user authored May 21, 2024
1 parent 5a31f72 commit 99b1857
Show file tree
Hide file tree
Showing 27 changed files with 239 additions and 227 deletions.
5 changes: 3 additions & 2 deletions configs/classification_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ model:
thickness: 2
include_plot: True

dataset:
name: cifar10_test
loader:
params:
dataset_name: cifar10_test

trainer:
preprocessing:
Expand Down
6 changes: 4 additions & 2 deletions configs/coco_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,14 @@ tracker:
wandb_entity: luxonis
is_mlflow: False

dataset:
name: coco_test
loader:
train_view: train
val_view: val
test_view: test

params:
dataset_name: coco_test

trainer:
accelerator: auto
devices: auto
Expand Down
5 changes: 3 additions & 2 deletions configs/detection_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ model:
params:
use_neck: True

dataset:
name: coco_test
loader:
params:
dataset_name: coco_test

trainer:
preprocessing:
Expand Down
5 changes: 3 additions & 2 deletions configs/example_export.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ model:
backbone: MicroNet
task: binary

dataset:
name: coco_test
loader:
params:
dataset_name: coco_test

trainer:
preprocessing:
Expand Down
5 changes: 3 additions & 2 deletions configs/example_tuning.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ model:
backbone: MicroNet
task: binary

dataset:
name: coco_test
loader:
params:
dataset_name: coco_test

trainer:
preprocessing:
Expand Down
5 changes: 3 additions & 2 deletions configs/keypoint_bbox_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ model:
predefined_model:
name: KeypointDetectionModel

dataset:
name: coco_test
loader:
params:
dataset_name: coco_test

trainer:
preprocessing:
Expand Down
5 changes: 3 additions & 2 deletions configs/resnet_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ model:
thickness: 2
include_plot: True

dataset:
name: cifar10_test
loader:
params:
dataset_name: cifar10_test

trainer:
batch_size: 4
Expand Down
5 changes: 3 additions & 2 deletions configs/segmentation_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ model:
backbone: MicroNet
task: binary

dataset:
name: coco_test
loader:
params:
dataset_name: coco_test

trainer:
preprocessing:
Expand Down
1 change: 1 addition & 0 deletions luxonis_train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .attached_modules import *
from .core import *
from .models import *
from .utils import *

Expand Down
113 changes: 49 additions & 64 deletions luxonis_train/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from typing import Annotated, Optional

import cv2
import torch
import typer
from torch.utils.data import DataLoader

from luxonis_train.utils.registry import LOADERS

app = typer.Typer(help="Luxonis Train CLI", add_completion=False)

Expand Down Expand Up @@ -105,7 +107,6 @@ def inspect(
"""Inspect dataset."""
from lightning.pytorch import seed_everything
from luxonis_ml.data import (
LuxonisDataset,
TrainAugmentations,
ValAugmentations,
)
Expand All @@ -117,7 +118,7 @@ def inspect(
get_unnormalized_images,
)
from luxonis_train.utils.config import Config
from luxonis_train.utils.loaders import LuxonisLoaderTorch, collate_fn
from luxonis_train.utils.loaders import collate_fn
from luxonis_train.utils.types import LabelType

overrides = {}
Expand All @@ -134,79 +135,63 @@ def inspect(

image_size = cfg.trainer.preprocessing.train_image_size

dataset = LuxonisDataset(
dataset_name=cfg.dataset.name,
team_id=cfg.dataset.team_id,
dataset_id=cfg.dataset.id,
bucket_type=cfg.dataset.bucket_type,
bucket_storage=cfg.dataset.bucket_storage,
)
augmentations = (
TrainAugmentations(
image_size=image_size,
augmentations=[
i.model_dump() for i in cfg.trainer.preprocessing.augmentations
],
train_rgb=cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio,
)
if view == "train"
else ValAugmentations(
image_size=image_size,
augmentations=[
i.model_dump() for i in cfg.trainer.preprocessing.augmentations
],
train_rgb=cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio,
)
augmentations = (TrainAugmentations if view == "train" else ValAugmentations)(
image_size=image_size,
augmentations=[i.model_dump() for i in cfg.trainer.preprocessing.augmentations],
train_rgb=cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio,
)

loader_train = LuxonisLoaderTorch(
dataset,
view=view,
augmentations=augmentations,
loader = LOADERS.get(cfg.loader.name)(
view=view, augmentations=augmentations, **cfg.loader.params
)

pytorch_loader_train = torch.utils.data.DataLoader(
loader_train,
batch_size=4,
num_workers=1,
pytorch_loader = DataLoader(
loader,
batch_size=1,
num_workers=0,
collate_fn=collate_fn,
)

if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)

counter = 0
for data in pytorch_loader_train:
imgs, label_dict = data
images = get_unnormalized_images(cfg, imgs)
for i, img in enumerate(images):
for label_type, labels in label_dict.items():
if label_type == LabelType.CLASSIFICATION:
continue
elif label_type == LabelType.BOUNDINGBOX:
img = draw_bounding_box_labels(
img, labels[labels[:, 0] == i][:, 2:], colors="yellow", width=1
)
elif label_type == LabelType.KEYPOINT:
img = draw_keypoint_labels(
img, labels[labels[:, 0] == i][:, 1:], colors="red"
for data in pytorch_loader:
imgs, task_dict = data
for task, label_dict in task_dict.items():
images = get_unnormalized_images(cfg, imgs)
for i, img in enumerate(images):
for label_type, labels in label_dict.items():
if label_type == LabelType.CLASSIFICATION:
continue
elif label_type == LabelType.BOUNDINGBOX:
img = draw_bounding_box_labels(
img,
labels[labels[:, 0] == i][:, 2:],
colors="yellow",
width=1,
)
elif label_type == LabelType.KEYPOINT:
img = draw_keypoint_labels(
img, labels[labels[:, 0] == i][:, 1:], colors="red"
)
elif label_type == LabelType.SEGMENTATION:
img = draw_segmentation_labels(
img, labels[i], alpha=0.8, colors="#5050FF"
)

img_arr = img.permute(1, 2, 0).numpy()
img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
if save_dir is not None:
counter += 1
cv2.imwrite(
os.path.join(save_dir, f"{counter}_{task}.png"), img_arr
)
elif label_type == LabelType.SEGMENTATION:
img = draw_segmentation_labels(
img, labels[i], alpha=0.8, colors="#5050FF"
)

img_arr = img.permute(1, 2, 0).numpy()
img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
if save_dir is not None:
counter += 1
cv2.imwrite(os.path.join(save_dir, f"{counter}.png"), img_arr)
else:
cv2.imshow("img", img_arr)
if cv2.waitKey() == ord("q"):
exit()
else:
cv2.imshow(task, img_arr)
if save_dir is None and cv2.waitKey() == ord("q"):
exit()


@app.command()
Expand Down
39 changes: 5 additions & 34 deletions luxonis_train/callbacks/test_on_train_end.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,14 @@
import lightning.pytorch as pl
from luxonis_ml.data import LuxonisDataset, ValAugmentations
from torch.utils.data import DataLoader

from luxonis_train.utils.config import Config
from luxonis_train.utils.loaders import LuxonisLoaderTorch, collate_fn
import luxonis_train
from luxonis_train.utils.registry import CALLBACKS


@CALLBACKS.register_module()
class TestOnTrainEnd(pl.Callback):
"""Callback to perform a test run at the end of the training."""

def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
cfg: Config = pl_module.cfg

dataset = LuxonisDataset(
dataset_name=cfg.dataset.name,
team_id=cfg.dataset.team_id,
dataset_id=cfg.dataset.id,
bucket_type=cfg.dataset.bucket_type,
bucket_storage=cfg.dataset.bucket_storage,
)

loader_test = LuxonisLoaderTorch(
dataset,
view=cfg.dataset.test_view,
augmentations=ValAugmentations(
image_size=cfg.trainer.preprocessing.train_image_size,
augmentations=[
i.model_dump() for i in cfg.trainer.preprocessing.augmentations
],
train_rgb=cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio,
),
)
pytorch_loader_test = DataLoader(
loader_test,
batch_size=cfg.trainer.batch_size,
num_workers=cfg.trainer.num_workers,
collate_fn=collate_fn,
)
trainer.test(pl_module, pytorch_loader_test)
def on_train_end(
self, trainer: pl.Trainer, pl_module: "luxonis_train.models.LuxonisModel"
) -> None:
trainer.test(pl_module, pl_module._core.pytorch_loaders["test"])
3 changes: 2 additions & 1 deletion luxonis_train/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .archiver import Archiver
from .core import Core
from .exporter import Exporter
from .inferer import Inferer
from .trainer import Trainer
from .tuner import Tuner

__all__ = ["Exporter", "Trainer", "Tuner", "Inferer", "Archiver"]
__all__ = ["Exporter", "Trainer", "Tuner", "Inferer", "Archiver", "Core"]
2 changes: 1 addition & 1 deletion luxonis_train/core/archiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
cfg=self.cfg,
dataset_metadata=self.dataset_metadata,
save_dir=self.run_save_dir,
input_shape=self.loader_train.input_shape,
input_shape=self.loaders["train"].input_shape,
)

self.model_name = self.cfg.model.name
Expand Down
Loading

0 comments on commit 99b1857

Please sign in to comment.