From 2964afc7315bc7579d78eafd6db8ad24ea3d26ac Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:23:50 +0100 Subject: [PATCH] Creator tests fix. --- .../ml/messages/creators/detection.py | 1 - .../test_classification_sequence.py | 7 +-- .../test_creators/test_classifications.py | 5 -- tests/unittests/test_creators/test_image.py | 43 +++-------------- .../unittests/test_creators/test_keypoints.py | 47 +++++++------------ .../test_creators/test_line_detections.py | 20 +++----- tests/unittests/test_creators/test_map.py | 8 ++-- .../test_creators/test_regression.py | 9 +--- .../test_creators/test_segmentation.py | 6 +-- .../test_creators/test_tracked_features.py | 25 +++------- 10 files changed, 46 insertions(+), 125 deletions(-) diff --git a/depthai_nodes/ml/messages/creators/detection.py b/depthai_nodes/ml/messages/creators/detection.py index 02a0601d..06acd7b2 100644 --- a/depthai_nodes/ml/messages/creators/detection.py +++ b/depthai_nodes/ml/messages/creators/detection.py @@ -55,7 +55,6 @@ def create_detection_message( @raise ValueError: If the keypoints are not a numpy array of shape (N, M, 2 or 3). @raise ValueError: If the masks are not a 3D numpy array of shape (img_height, img_width, N) or (N, img_height, img_width). - @raise ValueError: If the masks are not in the range [0, 1]. @raise ValueError: If the keypoints scores are not a numpy array. @raise ValueError: If the keypoints scores are not of shape [n_detections, n_keypoints, 1]. diff --git a/tests/unittests/test_creators/test_classification_sequence.py b/tests/unittests/test_creators/test_classification_sequence.py index d10c1944..126a4a11 100644 --- a/tests/unittests/test_creators/test_classification_sequence.py +++ b/tests/unittests/test_creators/test_classification_sequence.py @@ -46,11 +46,6 @@ def test_1d_scores(): create_classification_sequence_message(CLASSES, [0.5, 0.2, 0.3]) -def test_invalid_scores(): - with pytest.raises(ValueError): - create_classification_sequence_message(CLASSES, [0.7, 0.2, 0.1]) - - def test_mismatched_lengths(): with pytest.raises(ValueError): create_classification_sequence_message(CLASSES, [[0.7, 0.2], [0.1, 0.8]]) @@ -83,7 +78,7 @@ def test_integer_ignored_indexes(): create_classification_sequence_message(CLASSES, SCORES, ignored_indexes=[1.0]) -def test_2D_list_integers(): +def test_2D_list_ignored_integers(): with pytest.raises(ValueError): create_classification_sequence_message(CLASSES, SCORES, ignored_indexes=[[3]]) diff --git a/tests/unittests/test_creators/test_classifications.py b/tests/unittests/test_creators/test_classifications.py index 3ce353d4..7b29ea9c 100644 --- a/tests/unittests/test_creators/test_classifications.py +++ b/tests/unittests/test_creators/test_classifications.py @@ -60,11 +60,6 @@ def test_very_small_scores(): ) -def test_none_both(): - with pytest.raises(ValueError): - create_classification_message(None, None) - - def test_none_classes(): with pytest.raises(ValueError): create_classification_message(None, SCORES) diff --git a/tests/unittests/test_creators/test_image.py b/tests/unittests/test_creators/test_image.py index b24c8b6d..b7237866 100644 --- a/tests/unittests/test_creators/test_image.py +++ b/tests/unittests/test_creators/test_image.py @@ -17,35 +17,21 @@ def test_valid_hwc_bgr(): assert img_frame.getWidth() == 640 assert img_frame.getHeight() == 480 assert img_frame.getType() == dai.ImgFrame.Type.BGR888i + assert np.array_equal(img_frame.getCvFrame(), IMAGE) def test_valid_hwc_rgb(): - img_frame = create_image_message(IMAGE, is_bgr=False) - - assert isinstance(img_frame, dai.ImgFrame) - assert img_frame.getWidth() == 640 - assert img_frame.getHeight() == 480 - assert img_frame.getType() == dai.ImgFrame.Type.BGR888i + create_image_message(IMAGE, is_bgr=False) def test_valid_chw_bgr(): image = IMAGE.transpose(2, 0, 1) - img_frame = create_image_message(image, is_bgr=True) - - assert isinstance(img_frame, dai.ImgFrame) - assert img_frame.getWidth() == 640 - assert img_frame.getHeight() == 480 - assert img_frame.getType() == dai.ImgFrame.Type.BGR888i + create_image_message(image, is_bgr=True) def test_valid_chw_rgb(): image = IMAGE.transpose(2, 0, 1) - img_frame = create_image_message(image, is_bgr=False) - - assert isinstance(img_frame, dai.ImgFrame) - assert img_frame.getWidth() == 640 - assert img_frame.getHeight() == 480 - assert img_frame.getType() == dai.ImgFrame.Type.BGR888i + create_image_message(image, is_bgr=False) def test_valid_hwc_grayscale(): @@ -59,31 +45,16 @@ def test_valid_hwc_grayscale(): def test_valid_chw_grayscale(): image = IMAGE_GRAY.transpose(2, 0, 1) - img_frame = create_image_message(image, is_bgr=True) - - assert isinstance(img_frame, dai.ImgFrame) - assert img_frame.getWidth() == 640 - assert img_frame.getHeight() == 480 - assert img_frame.getType() == dai.ImgFrame.Type.GRAY8 + create_image_message(image, is_bgr=True) def test_invalid_shape(): image = np.random.randint(0, 256, (480, 640, 4), dtype=np.uint8) - with pytest.raises(ValueError, match="Unexpected image shape. Expected CHW or HWC"): + with pytest.raises(ValueError): create_image_message(image, is_bgr=True) def test_invalid_dtype(): image = IMAGE.astype(np.float32) - with pytest.raises( - ValueError, match="Expected int type, got ." - ): + with pytest.raises(ValueError): create_image_message(image, is_bgr=True) - - -def test_float_array(): - img = np.array([[[0.5, 0.5, 0.5]]]) - with pytest.raises( - ValueError, match="Expected int type, got ." - ): - create_image_message(img) diff --git a/tests/unittests/test_creators/test_keypoints.py b/tests/unittests/test_creators/test_keypoints.py index 1dacb90b..e6bf9470 100644 --- a/tests/unittests/test_creators/test_keypoints.py +++ b/tests/unittests/test_creators/test_keypoints.py @@ -41,76 +41,61 @@ def test_valid_keypoints_no_scores(): assert isinstance(message, Keypoints) assert len(message.keypoints) == 2 assert all(isinstance(kp, Keypoint) for kp in message.keypoints) - for i, kp in enumerate(message.keypoints): - assert kp.x == KPTS[i][0] - assert kp.y == KPTS[i][1] - assert kp.z == 0.0 + for kp in message.keypoints: assert kp.confidence == -1 def test_invalid_keypoints_type(): - with pytest.raises( - ValueError, match="Keypoints should be numpy array or list, got ." - ): + with pytest.raises(ValueError): create_keypoints_message("not a list or array") def test_invalid_scores_type(): - with pytest.raises( - ValueError, match="Scores should be numpy array or list, got ." - ): + with pytest.raises(ValueError): create_keypoints_message(KPTS, scores="not a list or array") def test_mismatched_keypoints_scores_length(): scores = [0.9] - with pytest.raises( - ValueError, match="Keypoints and scores should have the same length." - ): + with pytest.raises(ValueError): create_keypoints_message(KPTS, scores) def test_invalid_scores_values(): scores = [0.9, "not a float"] - with pytest.raises(ValueError, match="Scores should only contain float values."): + with pytest.raises(ValueError): create_keypoints_message(KPTS, scores) def test_scores_out_of_range(): scores = [0.9, 1.1] - with pytest.raises( - ValueError, match="Scores should only contain values between 0 and 1." - ): + with pytest.raises(ValueError): create_keypoints_message(KPTS, scores) def test_invalid_confidence_threshold_type(): - with pytest.raises( - ValueError, match="The confidence_threshold should be float, got ." - ): + with pytest.raises(ValueError): create_keypoints_message(KPTS, SCORES, confidence_threshold="not a float") def test_confidence_threshold_out_of_range(): - with pytest.raises( - ValueError, match="The confidence_threshold should be between 0 and 1." - ): + with pytest.raises(ValueError): create_keypoints_message(KPTS, SCORES, confidence_threshold=1.1) def test_invalid_keypoints_shape(): keypoints = [[0.1, 0.2, 0.3, 0.4]] - with pytest.raises( - ValueError, - match="All keypoints should be of dimension 2 or 3, got dimension 4.", - ): + with pytest.raises(ValueError): create_keypoints_message(keypoints) def test_invalid_keypoints_inner_type(): keypoints = [[0.1, "not a float"]] - with pytest.raises( - ValueError, - match="Keypoints inner list should contain only float, got .", - ): + with pytest.raises(ValueError): + create_keypoints_message(keypoints) + + +def test_all_keypoints_same_shape(): + keypoints = [[0.1, 0.2], [0.3, 0.4, 0.5]] + with pytest.raises(ValueError): create_keypoints_message(keypoints) diff --git a/tests/unittests/test_creators/test_line_detections.py b/tests/unittests/test_creators/test_line_detections.py index f0678d3a..51469317 100644 --- a/tests/unittests/test_creators/test_line_detections.py +++ b/tests/unittests/test_creators/test_line_detections.py @@ -35,41 +35,35 @@ def test_empty_lines(): def test_invalid_lines_type(): - with pytest.raises( - ValueError, match="Lines should be numpy array, got ." - ): + with pytest.raises(ValueError): create_line_detection_message(LINE.tolist(), SCORE) def test_invalid_lines_shape(): - with pytest.raises(ValueError, match="Lines should be of shape"): + with pytest.raises(ValueError): create_line_detection_message(LINE[0], SCORE) def test_invalid_lines_dimension(): - with pytest.raises(ValueError, match="Lines 2nd dimension should be of size 4"): + with pytest.raises(ValueError): create_line_detection_message(LINE[:, :3], SCORE) def test_invalid_scores_type(): - with pytest.raises( - ValueError, match="Scores should be numpy array, got ." - ): + with pytest.raises(ValueError): create_line_detection_message(LINE, SCORE.tolist()) def test_invalid_scores_shape(): - with pytest.raises(ValueError, match="Scores should be of shape"): + with pytest.raises(ValueError): create_line_detection_message(LINE, np.array([SCORE])) def test_invalid_scores_value_type(): - with pytest.raises( - ValueError, match="Scores should be of type float, got ." - ): + with pytest.raises(ValueError): create_line_detection_message(LINE, np.array([1], dtype=np.int64)) def test_mismatched_lines_scores_length(): - with pytest.raises(ValueError, match="Scores should have same length as lines"): + with pytest.raises(ValueError): create_line_detection_message(LINE, np.array([0.9, 0.8])) diff --git a/tests/unittests/test_creators/test_map.py b/tests/unittests/test_creators/test_map.py index d2af5e32..80df4f8b 100644 --- a/tests/unittests/test_creators/test_map.py +++ b/tests/unittests/test_creators/test_map.py @@ -46,19 +46,17 @@ def test_min_max_scaling(): def test_invalid_type(): - with pytest.raises(ValueError, match="Expected numpy array, got ."): + with pytest.raises(ValueError): create_map_message(MAP_ARRAY.tolist()) def test_invalid_shape(): - with pytest.raises(ValueError, match="Expected 2D or 3D input, got 1D input."): + with pytest.raises(ValueError): create_map_message(np.array([0.1, 0.2, 0.3, 0.4])) def test_invalid_3d_shape(): - with pytest.raises( - ValueError, match="Unexpected map shape. Expected NHW or HWN, got" - ): + with pytest.raises(ValueError): create_map_message(np.random.rand(2, 480, 640).astype(np.float32)) diff --git a/tests/unittests/test_creators/test_regression.py b/tests/unittests/test_creators/test_regression.py index b4728961..e2bb57c7 100644 --- a/tests/unittests/test_creators/test_regression.py +++ b/tests/unittests/test_creators/test_regression.py @@ -25,16 +25,11 @@ def test_empty_list(): def test_invalid_type(): - with pytest.raises( - ValueError, match="Predictions should be list, got ." - ): + with pytest.raises(ValueError): create_regression_message("not a list") def test_invalid_prediction_type(): predictions = [0.1, "not a float", 0.3] - with pytest.raises( - ValueError, - match="Each prediction should be a float, got instead.", - ): + with pytest.raises(ValueError): create_regression_message(predictions) diff --git a/tests/unittests/test_creators/test_segmentation.py b/tests/unittests/test_creators/test_segmentation.py index 9cd37364..0b21fcbd 100644 --- a/tests/unittests/test_creators/test_segmentation.py +++ b/tests/unittests/test_creators/test_segmentation.py @@ -16,18 +16,18 @@ def test_valid_input(): def test_invalid_type(): - with pytest.raises(ValueError, match="Expected numpy array, got ."): + with pytest.raises(ValueError): create_segmentation_message(MASK.tolist()) def test_invalid_shape(): mask = np.random.randint(0, 256, (480, 640, 3), dtype=np.int16) - with pytest.raises(ValueError, match="Expected 2D input, got 3D input."): + with pytest.raises(ValueError): create_segmentation_message(mask) def test_invalid_dtype(): - with pytest.raises(ValueError, match="Expected int16 input, got uint8."): + with pytest.raises(ValueError): create_segmentation_message(MASK.astype(np.uint8)) diff --git a/tests/unittests/test_creators/test_tracked_features.py b/tests/unittests/test_creators/test_tracked_features.py index 41b5229c..1feffaa5 100644 --- a/tests/unittests/test_creators/test_tracked_features.py +++ b/tests/unittests/test_creators/test_tracked_features.py @@ -36,49 +36,38 @@ def test_valid_input(): def test_invalid_reference_points_type(): - with pytest.raises( - ValueError, match="reference_points should be numpy array, got ." - ): + with pytest.raises(ValueError): create_tracked_features_message(REFERENCE_POINTS.tolist(), TARGET_POINTS) def test_invalid_reference_points_shape(): - with pytest.raises(ValueError, match="reference_points should be of shape"): + with pytest.raises(ValueError): create_tracked_features_message(REFERENCE_POINTS.flatten(), TARGET_POINTS) def test_invalid_reference_points_dimension(): reference_points = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) - with pytest.raises( - ValueError, match="reference_points 2nd dimension should be of size 2" - ): + with pytest.raises(ValueError): create_tracked_features_message(reference_points, TARGET_POINTS) def test_invalid_target_points_type(): - with pytest.raises( - ValueError, match="target_points should be numpy array, got ." - ): + with pytest.raises(ValueError): create_tracked_features_message(REFERENCE_POINTS, TARGET_POINTS.tolist()) def test_invalid_target_points_shape(): - with pytest.raises(ValueError, match="target_points should be of shape"): + with pytest.raises(ValueError): create_tracked_features_message(REFERENCE_POINTS, TARGET_POINTS.flatten()) def test_invalid_target_points_dimension(): target_points = np.array([[0.5, 0.6, 0.7], [0.8, 0.9, 1.0]]) - with pytest.raises( - ValueError, match="target_points 2nd dimension should be of size 2" - ): + with pytest.raises(ValueError): create_tracked_features_message(REFERENCE_POINTS, target_points) def test_mismatched_points_length(): target_points = np.array([[0.5, 0.6]]) - with pytest.raises( - ValueError, - match="The number of reference points and target points should be the same.", - ): + with pytest.raises(ValueError): create_tracked_features_message(REFERENCE_POINTS, target_points)