Run inference on the same device as used in the Trainer (#121)
sokovninn authored Nov 8, 2024
1 parent 509446b commit f845922
Showing 3 changed files with 102 additions and 48 deletions.
luxonis_train/attached_modules/base_attached_module.py (1 addition, 1 deletion)
@@ -327,7 +327,7 @@ def prepare(
             set(self.supported_tasks) & set(self.node_tasks)
         )
         x = self.get_input_tensors(inputs)
-        if labels is None:
+        if labels is None or len(labels) == 0:
            return x, None  # type: ignore
         label, task_type = self._get_label(labels)
         if task_type in [TaskType.CLASSIFICATION, TaskType.SEGMENTATION]:
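The guard now treats an empty labels dict the same as None, which matters for the new prediction path: batches coming out of pl_trainer.predict carry no annotations. A minimal sketch of the intended behavior, using a hypothetical stand-in for BaseAttachedModule.prepare rather than the real class:

import torch
from torch import Tensor


def prepare(x: Tensor, labels: dict[str, Tensor] | None):
    # Hypothetical reduction of BaseAttachedModule.prepare: during
    # prediction the loader may yield an empty labels dict instead of
    # None, so both cases must short-circuit before any label lookup.
    if labels is None or len(labels) == 0:
        return x, None
    return x, labels


print(prepare(torch.zeros(1), {}))    # (tensor([0.]), None)
print(prepare(torch.zeros(1), None))  # (tensor([0.]), None)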
luxonis_train/core/utils/infer_utils.py (85 additions, 47 deletions)
@@ -6,10 +6,13 @@
 import cv2
 import numpy as np
 import torch
+import torch.utils.data as torch_data
+from luxonis_ml.data import LuxonisDataset
 from torch import Tensor

 import luxonis_train
 from luxonis_train.attached_modules.visualizers import get_denormalized_images
+from luxonis_train.loaders import LuxonisLoaderTorch

 IMAGE_FORMATS = {
     ".bmp",
@@ -122,6 +125,61 @@ def infer_from_video(
     writer.release()


+def infer_from_loader(
+    model: "luxonis_train.core.LuxonisModel",
+    loader: torch_data.DataLoader,
+    save_dir: Path | None,
+    img_paths: list[Path] | None = None,
+) -> None:
+    """Runs inference on images from the given loader.
+
+    @type model: L{LuxonisModel}
+    @param model: The model to use for inference.
+    @type loader: torch_data.DataLoader
+    @param loader: The loader to use for inference.
+    @type save_dir: Path | None
+    @param save_dir: The directory to save the visualizations to.
+    @type img_paths: list[Path] | None
+    @param img_paths: The paths to the input images; used to name the
+        saved visualizations.
+    """
+    predictions = model.pl_trainer.predict(model.lightning_module, loader)
+
+    if predictions is None:
+        return
+
+    broken = False
+    for i, outputs in enumerate(predictions):
+        if broken:  # pragma: no cover
+            break
+        visualizations = outputs.visualizations  # type: ignore
+        batch_size = next(
+            iter(next(iter(visualizations.values())).values())
+        ).shape[0]
+        renders = process_visualizations(
+            visualizations,
+            batch_size=batch_size,
+        )
+        for j in range(batch_size):
+            for (node_name, viz_name), viz_batch in renders.items():
+                viz = viz_batch[j]
+                if save_dir is not None:
+                    if img_paths is not None:
+                        img_path = img_paths[i * batch_size + j]
+                        name = f"{img_path.stem}_{node_name}_{viz_name}"
+                    else:
+                        name = f"{node_name}_{viz_name}_{i * batch_size + j}"
+                    cv2.imwrite(str(save_dir / f"{name}.png"), viz)
+                else:
+                    cv2.imshow(f"{node_name}/{viz_name}", viz)
+
+            if not save_dir and window_closed():  # pragma: no cover
+                broken = True
+                break
+
+    cv2.destroyAllWindows()
 def infer_from_directory(
     model: "luxonis_train.core.LuxonisModel",
     img_paths: Iterable[Path],
@@ -136,27 +194,34 @@ def infer_from_directory(
     @type save_dir: Path | None
     @param save_dir: The directory to save the visualizations to.
     """
-    for img_path in img_paths:
-        img = cv2.imread(str(img_path))
-        outputs = prepare_and_infer_image(model, img)
-        renders = process_visualizations(outputs.visualizations, batch_size=1)
-
-        for (node_name, viz_name), [viz] in renders.items():
-            if save_dir is not None:
-                cv2.imwrite(
-                    str(
-                        save_dir
-                        / f"{img_path.stem}_{node_name}_{viz_name}.png"
-                    ),
-                    viz,
-                )
-            else:  # pragma: no cover
-                cv2.imshow(f"{node_name}/{viz_name}", viz)
-
-        if not save_dir and window_closed():  # pragma: no cover
-            break
-
-    cv2.destroyAllWindows()
+    img_paths = list(img_paths)
+
+    def generator():
+        for img_path in img_paths:
+            yield {
+                "file": img_path,
+            }
+
+    dataset_name = "infer_from_directory"
+    dataset = LuxonisDataset(dataset_name=dataset_name, delete_existing=True)
+    dataset.add(generator())
+    dataset.make_splits(
+        {"train": 0.0, "val": 0.0, "test": 1.0}, replace_old_splits=True
+    )
+
+    loader = LuxonisLoaderTorch(
+        dataset_name=dataset_name,
+        image_source="image",
+        view="test",
+        augmentations=model.val_augmentations,
+    )
+    loader = torch_data.DataLoader(
+        loader, batch_size=model.cfg.trainer.batch_size
+    )
+
+    infer_from_loader(model, loader, save_dir, img_paths)
+
+    dataset.delete_dataset()


 def infer_from_dataset(
@@ -173,33 +238,6 @@ def infer_from_dataset(
     @type save_dir: str | Path | None
     @param save_dir: The directory to save the visualizations to.
     """
-    broken = False
-    for i, (inputs, labels) in enumerate(model.pytorch_loaders[view]):
-        if broken:  # pragma: no cover
-            break
-
-        images = get_denormalized_images(model.cfg, inputs)
-        batch_size = images.shape[0]
-        outputs = model.lightning_module.forward(
-            inputs, labels, images=images, compute_visualizations=True
-        )
-        renders = process_visualizations(
-            outputs.visualizations,
-            batch_size=batch_size,
-        )
-        for j in range(batch_size):
-            for (node_name, viz_name), visualizations in renders.items():
-                viz = visualizations[j]
-                if save_dir is not None:
-                    name = f"{node_name}_{viz_name}"
-                    cv2.imwrite(
-                        str(save_dir / f"{name}_{i * batch_size + j}.png"), viz
-                    )
-                else:
-                    cv2.imshow(f"{node_name}/{viz_name}", viz)
-
-            if not save_dir and window_closed():  # pragma: no cover
-                broken = True
-                break
-
-    cv2.destroyAllWindows()
+    loader = model.pytorch_loaders[view]
+    infer_from_loader(model, loader, save_dir)
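With this refactor, infer_from_directory no longer loops over cv2.imread results; it wraps the images in a throwaway LuxonisDataset and defers to infer_from_loader. A hedged usage sketch, assuming LuxonisModel is constructed from a config path (the config and image paths below are placeholders, not files shipped with the repository):

from pathlib import Path

from luxonis_train.core import LuxonisModel
from luxonis_train.core.utils.infer_utils import infer_from_directory

# Placeholder config and image directory; adjust to your project.
model = LuxonisModel("configs/detection_model.yaml")

img_paths = sorted(Path("examples/images").glob("*.jpg"))
save_dir = Path("infer_output")
save_dir.mkdir(exist_ok=True)

# Builds a temporary LuxonisDataset, runs pl_trainer.predict over it,
# and writes one PNG per (image, node, visualization) into save_dir.
infer_from_directory(model, img_paths, save_dir)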
luxonis_train/models/luxonis_lightning.py (16 additions, 0 deletions)
@@ -686,6 +686,22 @@ def test_step(
"""Performs one step of testing with provided batch."""
return self._evaluation_step("test", test_batch)

def predict_step(
self, batch: tuple[dict[str, Tensor], Labels]
) -> LuxonisOutput:
"""Performs one step of prediction with provided batch."""
inputs, labels = batch
images = get_denormalized_images(self.cfg, inputs)
outputs = self.forward(
inputs,
labels,
images=images,
compute_visualizations=True,
compute_loss=False,
compute_metrics=False,
)
return outputs

def on_train_epoch_end(self) -> None:
"""Performs train epoch end operations."""
epoch_train_losses = self._average_losses(self.training_step_outputs)
Expand Down
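Routing inference through predict_step is what delivers the fix in the commit title: Trainer.predict moves both the LightningModule and every batch onto the Trainer's accelerator before the hook runs, so the manual forward() loop and its implicit device assumptions go away. A generic, self-contained Lightning sketch (not the project's code) illustrating that contract:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def predict_step(self, batch, batch_idx):
        (x,) = batch
        # By the time this hook runs, the Trainer has moved both the
        # module and the batch to its accelerator, so x.device matches
        # self.device without any manual .to(...) calls.
        return self.linear(x)


loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = pl.Trainer(accelerator="auto", devices=1, logger=False)
predictions = trainer.predict(TinyModule(), loader)
print([p.shape for p in predictions])  # two batches of shape (4, 2)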
