diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index 56937115..d4731988 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -83,5 +83,7 @@ def compute(self) -> tuple[Tensor, dict[str, Tensor]]: ) map = metric_dict.pop("map") - + # WARNING: fix DDP pl.log error + map = map.to(self.device) + metric_dict = {k: v.to(self.device) for k, v in metric_dict.items()} return map, metric_dict diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py index 198005e6..6a6440df 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py @@ -254,17 +254,36 @@ def compute(self) -> tuple[Tensor, dict[str, Tensor]]: self.coco_eval.summarize() stats = self.coco_eval.stats - kpt_map = torch.tensor([stats[0]], dtype=torch.float32) + device = self.pred_keypoints[0].device + kpt_map = torch.tensor([stats[0]], dtype=torch.float32, device=device) return kpt_map, { - "kpt_map_50": torch.tensor([stats[1]], dtype=torch.float32), - "kpt_map_75": torch.tensor([stats[2]], dtype=torch.float32), - "kpt_map_medium": torch.tensor([stats[3]], dtype=torch.float32), - "kpt_map_large": torch.tensor([stats[4]], dtype=torch.float32), - "kpt_mar": torch.tensor([stats[5]], dtype=torch.float32), - "kpt_mar_50": torch.tensor([stats[6]], dtype=torch.float32), - "kpt_mar_75": torch.tensor([stats[7]], dtype=torch.float32), - "kpt_mar_medium": torch.tensor([stats[8]], dtype=torch.float32), - "kpt_mar_large": torch.tensor([stats[9]], dtype=torch.float32), + "kpt_map_50": torch.tensor( + [stats[1]], dtype=torch.float32, device=device + ), + "kpt_map_75": torch.tensor( + [stats[2]], dtype=torch.float32, device=device + ), + "kpt_map_medium": torch.tensor( + [stats[3]], dtype=torch.float32, device=device + ), + "kpt_map_large": torch.tensor( + [stats[4]], dtype=torch.float32, device=device + ), + "kpt_mar": torch.tensor( + [stats[5]], dtype=torch.float32, device=device + ), + "kpt_mar_50": torch.tensor( + [stats[6]], dtype=torch.float32, device=device + ), + "kpt_mar_75": torch.tensor( + [stats[7]], dtype=torch.float32, device=device + ), + "kpt_mar_medium": torch.tensor( + [stats[8]], dtype=torch.float32, device=device + ), + "kpt_mar_large": torch.tensor( + [stats[9]], dtype=torch.float32, device=device + ), } def _get_coco_format(