diff --git a/tests/test_unet_masking.py b/tests/test_unet_masking.py index d1280ae63a..a47907072a 100644 --- a/tests/test_unet_masking.py +++ b/tests/test_unet_masking.py @@ -10,6 +10,8 @@ from topostats.unet_masking import dice_loss, iou_loss, make_bounding_box_square, pad_bounding_box, predict_unet +# pylint: disable=too-many-positional-arguments + @pytest.mark.parametrize( ("y_true", "y_pred", "smooth", "expected_loss"), @@ -149,18 +151,29 @@ def test_predict_unet(mock_model_5_by_5_single_class: MagicMock) -> None: ("crop_min_row", "crop_min_col", "crop_max_row", "crop_max_col", "image_shape", "expected_indices"), [ pytest.param(0, 0, 100, 100, (100, 100), (0, 0, 100, 100), id="not a crop"), - pytest.param(3, 4, 8, 8, (10, 10), (3, 4, 8, 9), id="free space single min col decrease"), - pytest.param(4, 3, 8, 8, (10, 10), (4, 3, 9, 8), id="free space single min row decrease"), - pytest.param(4, 4, 7, 8, (10, 10), (4, 4, 8, 8), id="free space single max col increase"), - pytest.param(4, 4, 8, 7, (10, 10), (4, 4, 8, 8), id="free space single max row increase"), - pytest.param(4, 2, 8, 8, (10, 10), (3, 2, 9, 8), id="free space double min col decrease"), - pytest.param(2, 4, 8, 8, (10, 10), (2, 3, 8, 9), id="free space double min row decrease"), - pytest.param(4, 4, 8, 6, (10, 10), (4, 3, 8, 7), id="free space double max col increase"), - pytest.param(4, 4, 6, 8, (10, 10), (3, 4, 7, 8), id="free space double max row increase"), - pytest.param(1, 1, 6, 2, (10, 10), (1, 1, 6, 6), id="constrained left"), - pytest.param(1, 6, 7, 8, (10, 10), (1, 2, 7, 8), id="constrained right"), - pytest.param(1, 1, 2, 6, (10, 10), (1, 1, 6, 6), id="constrained top"), - pytest.param(6, 1, 8, 7, (10, 10), (2, 1, 8, 7), id="constrained bottom"), + pytest.param(3, 4, 8, 8, (10, 10), (3, 4, 8, 9), id="free space single max col increase"), + pytest.param(4, 3, 8, 8, (10, 10), (4, 3, 9, 8), id="free space single max row increase"), + pytest.param(4, 4, 7, 8, (10, 10), (4, 4, 8, 8), id="free space single max row increase"), + pytest.param(4, 4, 8, 7, (10, 10), (4, 4, 8, 8), id="free space single max col increase"), + pytest.param(4, 2, 8, 8, (10, 10), (3, 2, 9, 8), id="free space double min row decrease, max row increase"), + pytest.param(2, 4, 8, 8, (10, 10), (2, 3, 8, 9), id="free space double min col decrease, max col increase"), + pytest.param(4, 4, 8, 6, (10, 10), (4, 3, 8, 7), id="free space double min col decrease, max col increase"), + pytest.param(4, 4, 6, 8, (10, 10), (3, 4, 7, 8), id="free space double min row decrease, max row increase"), + pytest.param(1, 1, 6, 2, (10, 10), (1, 0, 6, 5), id="constrained left"), + pytest.param(1, 6, 7, 8, (10, 10), (1, 3, 7, 9), id="constrained right"), + pytest.param(1, 1, 2, 6, (10, 10), (0, 1, 5, 6), id="constrained top"), + pytest.param(6, 1, 8, 7, (10, 10), (3, 1, 9, 7), id="constrained bottom"), + pytest.param(117, 20, 521, 603, (608, 608), (24, 20, 607, 603), id="constrained top and bottom"), + pytest.param( + 1, + 1, + 2, + 4, + (2, 4), + (0, 1, 2, 4), + id="rectangular image with large rectangular box", + marks=pytest.mark.xfail(reason="square bounding box not possible"), + ), ], ) def test_make_bounding_box_square( @@ -171,8 +184,16 @@ def test_make_bounding_box_square( image_shape: tuple[int, int], expected_indices: tuple[int, int, int, int], ) -> None: - """Test the make_bounding_box_square method.""" + """Test the make_bounding_box_square method returns square objects within image shape with known coordinates.""" result = make_bounding_box_square(crop_min_row, crop_min_col, crop_max_row, crop_max_col, image_shape) + # check bbox within image bounds + assert result[0] >= 0 + assert result[1] >= 0 + assert result[2] <= image_shape[0] + assert result[3] <= image_shape[1] + # check if square + assert (result[2] - result[0]) == (result[3] - result[1]) + # check bbox coords match assert result == expected_indices diff --git a/topostats/grains.py b/topostats/grains.py index 159c8f6e8e..9b7f6798da 100644 --- a/topostats/grains.py +++ b/topostats/grains.py @@ -33,6 +33,7 @@ # pylint: disable=too-many-arguments # pylint: disable=bare-except # pylint: disable=dangerous-default-value +# pylint: disable=too-many-positional-arguments # pylint: disable=too-many-lines # pylint: disable=too-many-public-methods @@ -699,13 +700,14 @@ def improve_grain_segmentation_unet( ) # Make the bounding box square within the confines of the image - bounding_box = make_bounding_box_square( - crop_min_row=bounding_box[0], - crop_min_col=bounding_box[1], - crop_max_row=bounding_box[2], - crop_max_col=bounding_box[3], - image_shape=(image.shape[0], image.shape[1]), - ) + if (bounding_box[2] - bounding_box[0]) != (bounding_box[3] - bounding_box[1]): + bounding_box = make_bounding_box_square( + crop_min_row=bounding_box[0], + crop_min_col=bounding_box[1], + crop_max_row=bounding_box[2], + crop_max_col=bounding_box[3], + image_shape=(image.shape[0], image.shape[1]), + ) # Grab the cropped image. Using slice since the bounding box from skimage is # half-open, so the max_row and max_col are not included in the region. diff --git a/topostats/unet_masking.py b/topostats/unet_masking.py index a9734fe83b..79bc55e427 100644 --- a/topostats/unet_masking.py +++ b/topostats/unet_masking.py @@ -14,6 +14,9 @@ LOGGER = logging.getLogger(LOGGER_NAME) +# pylint: disable=too-many-positional-arguments +# pylint: disable=too-many-locals + # DICE Loss def dice_loss(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32], smooth: float = 1e-5) -> tf.Tensor: @@ -199,7 +202,7 @@ def predict_unet( # Resize the channel mask to the original image size, but we want boolean so use nearest neighbour # Sylvia: Pylint incorrectly thinks that Image.NEAREST is not a member of Image. IDK why. # pylint: disable=no-member - channel_mask_PIL = channel_mask_PIL.resize((original_image.shape[0], original_image.shape[1]), Image.NEAREST) + channel_mask_PIL = channel_mask_PIL.resize((original_image.shape[1], original_image.shape[0]), Image.NEAREST) resized_predicted_mask[:, :, channel_index] = np.array(channel_mask_PIL).astype(bool) return resized_predicted_mask @@ -245,45 +248,39 @@ def make_bounding_box_square( if crop_min_col - diff // 2 >= 0 and crop_max_col + diff - diff // 2 < image_shape[1]: new_crop_min_col = crop_min_col - diff // 2 new_crop_max_col = crop_max_col + diff - diff // 2 - # If we can't expand uniformly, expand in just one direction - # Check if we can expand right - elif crop_max_col + diff - diff // 2 < image_shape[1]: - # We can expand right - new_crop_min_col = crop_min_col - new_crop_max_col = crop_max_col + diff - elif crop_min_col - diff // 2 >= 0: - # We can expand left - new_crop_min_col = crop_min_col - diff - new_crop_max_col = crop_max_col + # If we can't expand uniformly, expand as much as possible in that dir + else: + # Crop expansion below 0 + if crop_min_col - diff // 2 <= 0: + new_crop_min_col = 0 + new_crop_max_col = crop_max_col + (diff - crop_min_col) + # Crop expansion beyond image size + else: + new_crop_max_col = image_shape[1] - 1 + new_crop_min_col = crop_min_col - (diff - (image_shape[1] - 1 - crop_max_col)) # Set the new crop height to the original crop height since we are just updating the width new_crop_min_row = crop_min_row new_crop_max_row = crop_max_row - elif crop_width > crop_height: + else: # The crop is wider than it is tall diff = crop_width - crop_height # Check if we can expand equally in each direction if crop_min_row - diff // 2 >= 0 and crop_max_row + diff - diff // 2 < image_shape[0]: new_crop_min_row = crop_min_row - diff // 2 new_crop_max_row = crop_max_row + diff - diff // 2 - # If we can't expand uniformly, expand in just one direction - # Check if we can expand down - elif crop_max_row + diff - diff // 2 < image_shape[0]: - # We can expand down - new_crop_min_row = crop_min_row - new_crop_max_row = crop_max_row + diff - elif crop_min_row - diff // 2 >= 0: - # We can expand up - new_crop_min_row = crop_min_row - diff - new_crop_max_row = crop_max_row + # If we can't expand uniformly, expand as much as possible in that dir + else: + # Crop expansion below 0 + if crop_min_row - diff // 2 <= 0: + new_crop_min_row = 0 + new_crop_max_row = crop_max_row + (diff - crop_min_row) + # Crop expansion beyond image size + else: + new_crop_max_row = image_shape[0] - 1 + new_crop_min_row = crop_min_row - (diff - (image_shape[0] - 1 - crop_max_row)) # Set the new crop width to the original crop width since we are just updating the height new_crop_min_col = crop_min_col new_crop_max_col = crop_max_col - else: - # If the crop is already square, return the original crop - new_crop_min_row = crop_min_row - new_crop_min_col = crop_min_col - new_crop_max_row = crop_max_row - new_crop_max_col = crop_max_col return new_crop_min_row, new_crop_min_col, new_crop_max_row, new_crop_max_col