fix: seg loss, batch vis, and mAP for seg
JSabadin committed Dec 10, 2024
1 parent 6b7710e commit 0f842e1
Showing 7 changed files with 231 additions and 201 deletions.
@@ -39,7 +39,9 @@ def __init__(
**kwargs: Any,
):
"""BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8
<https://arxiv.org/pdf/2305.09972>}
<https://arxiv.org/pdf/2305.09972>} and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications
<https://arxiv.org/pdf/2209.02976.pdf>}.
Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}.
@type reg_max: int
@param reg_max: Maximum number of regression channels. Defaults to 16.
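
Both loss docstrings in this commit document reg_max as the number of regression channels, i.e. a Distribution Focal Loss (DFL) head: each box side is predicted as a categorical distribution over reg_max bins and decoded as its expectation. Below is a minimal sketch of that decoding step, assuming standard YOLOv8-style DFL; decode_dfl and the shapes used are illustrative, not this repository's API.

import torch
import torch.nn.functional as F

def decode_dfl(pred_distri: torch.Tensor, reg_max: int = 16) -> torch.Tensor:
    """Decode DFL logits (B, N_anchor, 4 * reg_max) into per-side
    distances (B, N_anchor, 4), in units of the anchor stride."""
    b, n, _ = pred_distri.shape
    # Each side's reg_max logits form a categorical distribution over bins...
    probs = F.softmax(pred_distri.view(b, n, 4, reg_max), dim=-1)
    # ...whose expectation over bin indices 0..reg_max-1 is the decoded distance.
    bins = torch.arange(reg_max, dtype=probs.dtype, device=probs.device)
    return (probs * bins).sum(-1)

dist = decode_dfl(torch.randn(2, 8400, 64))  # 8400 anchors -> (2, 8400, 4)
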
168 changes: 55 additions & 113 deletions luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py
@@ -34,11 +34,12 @@ def __init__(
class_loss_weight: float = 0.5,
bbox_loss_weight: float = 7.5,
dfl_loss_weight: float = 1.5,
overlap_mask: bool = True,
**kwargs: Any,
):
"""Instance Segmentation and BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8
<https://arxiv.org/pdf/2305.09972>}
<https://arxiv.org/pdf/2305.09972>} and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications
<https://arxiv.org/pdf/2209.02976.pdf>}.
Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}.
@type reg_max: int
@param reg_max: Maximum number of regression channels. Defaults to 16.
@@ -59,7 +60,6 @@ def __init__(
dfl_loss_weight=dfl_loss_weight,
**kwargs,
)
self.overlap = overlap_mask

def prepare(
self, inputs: Packet[Tensor], labels: Labels
@@ -73,7 +73,7 @@ def prepare(
[xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2
).split((self.node.reg_max * 4, self.n_classes), 1)
target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX)
img_idx = target_bbox[:, 0]
img_idx = target_bbox[:, 0].unsqueeze(-1)
target_masks = self.get_label(labels, TaskType.INSTANCE_SEGMENTATION)
if tuple(target_masks.shape[-2:]) != (mask_h, mask_w):
target_masks = F.interpolate(
@@ -153,15 +153,14 @@ def forward(
loss_iou = torch.tensor(0.0).to(pred_distri.device)
loss_dfl = torch.tensor(0.0).to(pred_distri.device)

loss_seg = self.calculate_segmentation_loss(
loss_seg = self.compute_segmentation_loss(
mask_positive,
target_masks,
assigned_gt_idx,
assigned_bboxes,
img_idx,
proto,
pred_masks,
self.overlap,
)

loss = (
@@ -179,122 +178,65 @@

return loss, sub_losses

# TODO: Modify after adding correct annotation loading
def calculate_segmentation_loss(
def compute_segmentation_loss(
self,
fg_mask: torch.Tensor,
masks: torch.Tensor,
target_gt_idx: torch.Tensor,
target_bboxes: torch.Tensor,
batch_idx: torch.Tensor,
gt_masks: torch.Tensor,
gt_idx: torch.Tensor,
bboxes: torch.Tensor,
batch_ids: torch.Tensor,
proto: torch.Tensor,
pred_masks: torch.Tensor,
overlap: bool,
) -> torch.Tensor:
"""Calculate the loss for instance segmentation.
Args:
fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
overlap (bool): Whether the masks in `masks` tensor overlap.
Returns:
(torch.Tensor): The calculated loss for instance segmentation.
Notes:
The batch loss can be computed for improved speed at higher memory usage.
For example, pred_mask can be computed as follows:
pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
"""Compute the segmentation loss for the entire batch.
@type fg_mask: torch.Tensor
@param fg_mask: Foreground mask. Shape: (B, N_anchor).
@type gt_masks: torch.Tensor
@param gt_masks: Ground truth masks. Shape: (n, H, W).
@type gt_idx: torch.Tensor
@param gt_idx: Ground truth mask indices. Shape: (B, N_anchor).
@type bboxes: torch.Tensor
@param bboxes: Ground truth bounding boxes in xyxy format.
Shape: (B, N_anchor, 4).
@type batch_ids: torch.Tensor
@param batch_ids: Batch indices. Shape: (n, 1).
@type proto: torch.Tensor
@param proto: Prototype masks. Shape: (B, 32, H, W).
@type pred_masks: torch.Tensor
@param pred_masks: Predicted mask coefficients. Shape: (B,
N_anchor, 32).
"""
_, _, mask_h, mask_w = proto.shape
loss = 0

# Normalize to 0-1
target_bboxes_normalized = target_bboxes / self.gt_bboxes_scale

# Areas of target bboxes
marea = box_convert(
target_bboxes_normalized, in_fmt="xyxy", out_fmt="xywh"
)[..., 2:].prod(2)

# Normalize to mask size
mxyxy = target_bboxes_normalized * torch.tensor(
[mask_w, mask_h, mask_w, mask_h], device=proto.device
_, _, h, w = proto.shape
total_loss = 0
bboxes_norm = bboxes / self.gt_bboxes_scale
bbox_area = box_convert(bboxes_norm, in_fmt="xyxy", out_fmt="xywh")[
..., 2:
].prod(2)
bboxes_scaled = bboxes_norm * torch.tensor(
[w, h, w, h], device=proto.device
)

for i, single_i in enumerate(
zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)
for img_idx, data in enumerate(
zip(fg_mask, gt_idx, pred_masks, proto, bboxes_scaled, bbox_area)
):
(
fg_mask_i,
target_gt_idx_i,
pred_masks_i,
proto_i,
mxyxy_i,
marea_i,
masks_i,
) = single_i
if fg_mask_i.any():
mask_idx = target_gt_idx_i[fg_mask_i]
if overlap:
gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
gt_mask = gt_mask.float()
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

loss += self.single_mask_loss(
gt_mask,
pred_masks_i[fg_mask_i],
proto_i,
mxyxy_i[fg_mask_i],
marea_i[fg_mask_i],
fg, gt, pred, pr, bbox, area = data
if fg.any():
mask_ids = gt[fg]
gt_mask = gt_masks[batch_ids.view(-1) == img_idx][mask_ids]

# Compute individual image mask loss
pred_mask = torch.einsum("in,nhw->ihw", pred[fg], pr)
loss = F.binary_cross_entropy_with_logits(
pred_mask, gt_mask, reduction="none"
)

# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
total_loss += (
apply_bounding_box_to_masks(loss, bbox[fg]).mean(
dim=(1, 2)
)
/ area[fg]
).sum()
else:
loss += (proto * 0).sum() + (
pred_masks * 0
).sum() # inf sums may lead to nan loss

return loss / fg_mask.sum()
total_loss += (proto * 0).sum() + (pred_masks * 0).sum()

# TODO: Modify after adding correct annotation loading
@staticmethod
def single_mask_loss(
gt_mask: torch.Tensor,
pred: torch.Tensor,
proto: torch.Tensor,
xyxy: torch.Tensor,
area: torch.Tensor,
) -> torch.Tensor:
"""Compute the instance segmentation loss for a single image.
Args:
gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
proto (torch.Tensor): Prototype masks of shape (32, H, W).
xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
area (torch.Tensor): Area of each ground truth bounding box of shape (n,).
Returns:
(torch.Tensor): The calculated mask loss for a single image.
Notes:
The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
predicted masks from the prototype masks and predicted mask coefficients.
"""
pred_mask = torch.einsum(
"in,nhw->ihw", pred, proto
) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
loss = F.binary_cross_entropy_with_logits(
pred_mask, gt_mask, reduction="none"
)
return (
apply_bounding_box_to_masks(loss, xyxy).mean(dim=(1, 2)) / area
).sum()
return total_loss / fg_mask.sum()
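
Read end to end, the new compute_segmentation_loss path is, per image: select positive anchors via fg_mask, gather their ground-truth masks, assemble predicted mask logits as a linear combination of the 32 prototypes (the einsum), take per-pixel BCE, crop the loss map to the ground-truth box, and normalize by box area. Below is a self-contained sketch of that pipeline for a single image; crop_loss_to_box is an assumed stand-in for the repository's apply_bounding_box_to_masks helper, and all shapes are toy values.

import torch
import torch.nn.functional as F

def crop_loss_to_box(loss_map: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
    """Zero the loss outside each box (assumed behavior of
    apply_bounding_box_to_masks). loss_map: (n, H, W); boxes: (n, 4)
    xyxy in mask-pixel coordinates."""
    n, h, w = loss_map.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, dim=1)  # each (n, 1, 1)
    cols = torch.arange(w, device=loss_map.device)[None, None, :]
    rows = torch.arange(h, device=loss_map.device)[None, :, None]
    inside = (cols >= x1) & (cols < x2) & (rows >= y1) & (rows < y2)
    return loss_map * inside

# Toy shapes: 6 anchors, 3 positive, 32 prototypes, 40x40 mask grid.
n_anchor, n_proto, h, w = 6, 32, 40, 40
fg = torch.tensor([True, False, True, True, False, False])
pred_coefs = torch.randn(n_anchor, n_proto)          # mask coefficients per anchor
proto = torch.randn(n_proto, h, w)                   # shared prototype masks
gt_mask = (torch.rand(int(fg.sum()), h, w) > 0.5).float()
boxes = torch.tensor([[4.0, 4.0, 30.0, 30.0]]).repeat(int(fg.sum()), 1)
# As in the diff: crop boxes are in mask pixels, areas are normalized to [0, 1].
area = ((boxes[:, 2] - boxes[:, 0]) / w) * ((boxes[:, 3] - boxes[:, 1]) / h)

# (n_pos, 32) @ (32, H, W) -> (n_pos, H, W): per-anchor predicted mask logits.
pred_mask = torch.einsum("in,nhw->ihw", pred_coefs[fg], proto)
bce = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
# Crop to the box, average the surviving pixels, normalize by box area,
# then average over positives (the real code divides by fg_mask.sum()).
loss = (crop_loss_to_box(bce, boxes).mean(dim=(1, 2)) / area).sum() / fg.sum()
print(loss)
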