Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifies the squaring bbox and tweaks PIL input shapes #1030

Merged
merged 7 commits into from
Dec 11, 2024
47 changes: 34 additions & 13 deletions tests/test_unet_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from topostats.unet_masking import dice_loss, iou_loss, make_bounding_box_square, pad_bounding_box, predict_unet

# pylint: disable=too-many-positional-arguments


@pytest.mark.parametrize(
("y_true", "y_pred", "smooth", "expected_loss"),
Expand Down Expand Up @@ -149,18 +151,29 @@ def test_predict_unet(mock_model_5_by_5_single_class: MagicMock) -> None:
("crop_min_row", "crop_min_col", "crop_max_row", "crop_max_col", "image_shape", "expected_indices"),
[
pytest.param(0, 0, 100, 100, (100, 100), (0, 0, 100, 100), id="not a crop"),
pytest.param(3, 4, 8, 8, (10, 10), (3, 4, 8, 9), id="free space single min col decrease"),
pytest.param(4, 3, 8, 8, (10, 10), (4, 3, 9, 8), id="free space single min row decrease"),
pytest.param(4, 4, 7, 8, (10, 10), (4, 4, 8, 8), id="free space single max col increase"),
pytest.param(4, 4, 8, 7, (10, 10), (4, 4, 8, 8), id="free space single max row increase"),
pytest.param(4, 2, 8, 8, (10, 10), (3, 2, 9, 8), id="free space double min col decrease"),
pytest.param(2, 4, 8, 8, (10, 10), (2, 3, 8, 9), id="free space double min row decrease"),
pytest.param(4, 4, 8, 6, (10, 10), (4, 3, 8, 7), id="free space double max col increase"),
pytest.param(4, 4, 6, 8, (10, 10), (3, 4, 7, 8), id="free space double max row increase"),
pytest.param(1, 1, 6, 2, (10, 10), (1, 1, 6, 6), id="constrained left"),
pytest.param(1, 6, 7, 8, (10, 10), (1, 2, 7, 8), id="constrained right"),
pytest.param(1, 1, 2, 6, (10, 10), (1, 1, 6, 6), id="constrained top"),
pytest.param(6, 1, 8, 7, (10, 10), (2, 1, 8, 7), id="constrained bottom"),
pytest.param(3, 4, 8, 8, (10, 10), (3, 4, 8, 9), id="free space single max col increase"),
pytest.param(4, 3, 8, 8, (10, 10), (4, 3, 9, 8), id="free space single max row increase"),
pytest.param(4, 4, 7, 8, (10, 10), (4, 4, 8, 8), id="free space single max row increase"),
pytest.param(4, 4, 8, 7, (10, 10), (4, 4, 8, 8), id="free space single max col increase"),
pytest.param(4, 2, 8, 8, (10, 10), (3, 2, 9, 8), id="free space double min row decrease, max row increase"),
pytest.param(2, 4, 8, 8, (10, 10), (2, 3, 8, 9), id="free space double min col decrease, max col increase"),
pytest.param(4, 4, 8, 6, (10, 10), (4, 3, 8, 7), id="free space double min col decrease, max col increase"),
pytest.param(4, 4, 6, 8, (10, 10), (3, 4, 7, 8), id="free space double min row decrease, max row increase"),
pytest.param(1, 1, 6, 2, (10, 10), (1, 0, 6, 5), id="constrained left"),
pytest.param(1, 6, 7, 8, (10, 10), (1, 3, 7, 9), id="constrained right"),
pytest.param(1, 1, 2, 6, (10, 10), (0, 1, 5, 6), id="constrained top"),
pytest.param(6, 1, 8, 7, (10, 10), (3, 1, 9, 7), id="constrained bottom"),
pytest.param(117, 20, 521, 603, (608, 608), (24, 20, 607, 603), id="constrained top and bottom"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Collaborator

@ns-rse ns-rse Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I've tried looking at this in detail I find the dimensions quite hard to think about, are there smaller image_shape and co-ordinates that would be constrained top and bottom?

I think if image_shape is always square, which isn't necessarily the case since I recall @SylviaWhittle writing code to handle none-square images, then we'd only ever see a constraint on one side. Something that would be constrained top and bottom would be a wide rectangular bounding box within a wide rectangular image.

EDIT: I added just such a set of test parameters to the above mentioned commit and marked it as pytest.mark.xfail() with a reason explaining why. If cherry picking it will be included.

pytest.param(
1,
1,
2,
4,
(2, 4),
(0, 1, 2, 4),
id="rectangular image with large rectangular box",
marks=pytest.mark.xfail(reason="square bounding box not possible"),
),
],
)
def test_make_bounding_box_square(
Expand All @@ -171,8 +184,16 @@ def test_make_bounding_box_square(
image_shape: tuple[int, int],
expected_indices: tuple[int, int, int, int],
) -> None:
"""Test the make_bounding_box_square method."""
"""Test the make_bounding_box_square method returns square objects within image shape with known coordinates."""
result = make_bounding_box_square(crop_min_row, crop_min_col, crop_max_row, crop_max_col, image_shape)
# check bbox within image bounds
assert result[0] >= 0
assert result[1] >= 0
assert result[2] <= image_shape[0]
assert result[3] <= image_shape[1]
# check if square
assert (result[2] - result[0]) == (result[3] - result[1])
# check bbox coords match
assert result == expected_indices


Expand Down
16 changes: 9 additions & 7 deletions topostats/grains.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# pylint: disable=too-many-arguments
# pylint: disable=bare-except
# pylint: disable=dangerous-default-value
# pylint: disable=too-many-positional-arguments
# pylint: disable=too-many-lines
# pylint: disable=too-many-public-methods

Expand Down Expand Up @@ -699,13 +700,14 @@ def improve_grain_segmentation_unet(
)

# Make the bounding box square within the confines of the image
bounding_box = make_bounding_box_square(
crop_min_row=bounding_box[0],
crop_min_col=bounding_box[1],
crop_max_row=bounding_box[2],
crop_max_col=bounding_box[3],
image_shape=(image.shape[0], image.shape[1]),
)
if (bounding_box[2] - bounding_box[0]) != (bounding_box[3] - bounding_box[1]):
bounding_box = make_bounding_box_square(
crop_min_row=bounding_box[0],
crop_min_col=bounding_box[1],
crop_max_row=bounding_box[2],
crop_max_col=bounding_box[3],
image_shape=(image.shape[0], image.shape[1]),
)

# Grab the cropped image. Using slice since the bounding box from skimage is
# half-open, so the max_row and max_col are not included in the region.
Expand Down
53 changes: 25 additions & 28 deletions topostats/unet_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

LOGGER = logging.getLogger(LOGGER_NAME)

# pylint: disable=too-many-positional-arguments
# pylint: disable=too-many-locals


# DICE Loss
def dice_loss(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32], smooth: float = 1e-5) -> tf.Tensor:
Expand Down Expand Up @@ -199,7 +202,7 @@ def predict_unet(
# Resize the channel mask to the original image size, but we want boolean so use nearest neighbour
# Sylvia: Pylint incorrectly thinks that Image.NEAREST is not a member of Image. IDK why.
# pylint: disable=no-member
channel_mask_PIL = channel_mask_PIL.resize((original_image.shape[0], original_image.shape[1]), Image.NEAREST)
channel_mask_PIL = channel_mask_PIL.resize((original_image.shape[1], original_image.shape[0]), Image.NEAREST)
resized_predicted_mask[:, :, channel_index] = np.array(channel_mask_PIL).astype(bool)

return resized_predicted_mask
Expand Down Expand Up @@ -245,45 +248,39 @@ def make_bounding_box_square(
if crop_min_col - diff // 2 >= 0 and crop_max_col + diff - diff // 2 < image_shape[1]:
new_crop_min_col = crop_min_col - diff // 2
new_crop_max_col = crop_max_col + diff - diff // 2
# If we can't expand uniformly, expand in just one direction
# Check if we can expand right
elif crop_max_col + diff - diff // 2 < image_shape[1]:
# We can expand right
new_crop_min_col = crop_min_col
new_crop_max_col = crop_max_col + diff
elif crop_min_col - diff // 2 >= 0:
# We can expand left
new_crop_min_col = crop_min_col - diff
new_crop_max_col = crop_max_col
# If we can't expand uniformly, expand as much as possible in that dir
else:
# Crop expansion below 0
if crop_min_col - diff // 2 <= 0:
new_crop_min_col = 0
new_crop_max_col = crop_max_col + (diff - crop_min_col)
# Crop expansion beyond image size
else:
new_crop_max_col = image_shape[1] - 1
new_crop_min_col = crop_min_col - (diff - (image_shape[1] - 1 - crop_max_col))
# Set the new crop height to the original crop height since we are just updating the width
new_crop_min_row = crop_min_row
new_crop_max_row = crop_max_row
elif crop_width > crop_height:
else:
# The crop is wider than it is tall
diff = crop_width - crop_height
# Check if we can expand equally in each direction
if crop_min_row - diff // 2 >= 0 and crop_max_row + diff - diff // 2 < image_shape[0]:
new_crop_min_row = crop_min_row - diff // 2
new_crop_max_row = crop_max_row + diff - diff // 2
# If we can't expand uniformly, expand in just one direction
# Check if we can expand down
elif crop_max_row + diff - diff // 2 < image_shape[0]:
# We can expand down
new_crop_min_row = crop_min_row
new_crop_max_row = crop_max_row + diff
elif crop_min_row - diff // 2 >= 0:
# We can expand up
new_crop_min_row = crop_min_row - diff
new_crop_max_row = crop_max_row
# If we can't expand uniformly, expand as much as possible in that dir
else:
# Crop expansion below 0
if crop_min_row - diff // 2 <= 0:
new_crop_min_row = 0
new_crop_max_row = crop_max_row + (diff - crop_min_row)
# Crop expansion beyond image size
else:
new_crop_max_row = image_shape[0] - 1
new_crop_min_row = crop_min_row - (diff - (image_shape[0] - 1 - crop_max_row))
# Set the new crop width to the original crop width since we are just updating the height
new_crop_min_col = crop_min_col
new_crop_max_col = crop_max_col
else:
# If the crop is already square, return the original crop
new_crop_min_row = crop_min_row
new_crop_min_col = crop_min_col
new_crop_max_row = crop_max_row
new_crop_max_col = crop_max_col

return new_crop_min_row, new_crop_min_col, new_crop_max_row, new_crop_max_col

Expand Down
Loading