From 645b3e0ffd30c99b9970f05ad1de4cd5e17f66d9 Mon Sep 17 00:00:00 2001 From: HonzaCuhel Date: Wed, 28 Aug 2024 22:22:01 +0200 Subject: [PATCH] Update Pose head --- tools/modules/heads.py | 22 ++++++++++++---------- tools/yolo/yolov8_exporter.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tools/modules/heads.py b/tools/modules/heads.py index 8e92df9..f92ce05 100644 --- a/tools/modules/heads.py +++ b/tools/modules/heads.py @@ -16,9 +16,10 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y sy, sx = torch.meshgrid(sy, sx, indexing="ij") - anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) - stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) - return torch.cat(anchor_points), torch.cat(stride_tensor) + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2).transpose(0, 1)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device).transpose(0, 1)) + return anchor_points, stride_tensor + # return torch.cat(anchor_points), torch.cat(stride_tensor) class DetectV5(nn.Module): @@ -432,26 +433,27 @@ def forward(self, x): """Perform forward pass through YOLO model and return predictions.""" bs = x[0].shape[0] # batch size if self.shape != bs: - self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.anchors, self.strides = make_anchors(x, self.stride, 0.5) self.shape = bs # Detection part outputs = super().forward(x) # Pose part - kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) - pred_kpt = self.kpts_decode(bs, kpt) - outputs.append(pred_kpt) + for i in range(self.nl): + kpt = self.cv4[i](x[i]).view(bs, self.nk, -1) + outputs.append(self.kpts_decode(bs, kpt, i)) return outputs - def kpts_decode(self, bs, kpts): + def kpts_decode(self, bs, kpts, i): """Decodes keypoints.""" ndim = self.kpt_shape[1] y = kpts.view(bs, *self.kpt_shape, -1) - a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides + a = (y[:, :, :2] * 2.0 + (self.anchors[i] - 0.5)) * self.strides[i] if ndim == 3: - a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2) + # a = torch.cat((a, y[:, :, 2:3].sigmoid()*10), 2) + a = torch.cat((a, y[:, :, 2:3]), 2) return a.view(bs, self.nk, -1) diff --git a/tools/yolo/yolov8_exporter.py b/tools/yolo/yolov8_exporter.py index 108088d..bfc3b6a 100644 --- a/tools/yolo/yolov8_exporter.py +++ b/tools/yolo/yolov8_exporter.py @@ -44,7 +44,7 @@ def get_output_names(mode: int) -> List[str]: elif mode == OBB_MODE: return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "angle_output"] elif mode == POSE_MODE: - return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "kpt_output"] + return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "kpt_output1", "kpt_output2", "kpt_output3"] return ["output"]