Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enables legacy unets to run #1042

Merged
merged 6 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions tests/test_unet_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from topostats.unet_masking import dice_loss, iou_loss, make_bounding_box_square, pad_bounding_box, predict_unet

# pylint: disable=too-many-positional-arguments


@pytest.mark.parametrize(
("y_true", "y_pred", "smooth", "expected_loss"),
Expand Down Expand Up @@ -157,10 +159,11 @@ def test_predict_unet(mock_model_5_by_5_single_class: MagicMock) -> None:
pytest.param(2, 4, 8, 8, (10, 10), (2, 3, 8, 9), id="free space double min row decrease"),
pytest.param(4, 4, 8, 6, (10, 10), (4, 3, 8, 7), id="free space double max col increase"),
pytest.param(4, 4, 6, 8, (10, 10), (3, 4, 7, 8), id="free space double max row increase"),
pytest.param(1, 1, 6, 2, (10, 10), (1, 1, 6, 6), id="constrained left"),
pytest.param(1, 6, 7, 8, (10, 10), (1, 2, 7, 8), id="constrained right"),
pytest.param(1, 1, 2, 6, (10, 10), (1, 1, 6, 6), id="constrained top"),
pytest.param(6, 1, 8, 7, (10, 10), (2, 1, 8, 7), id="constrained bottom"),
pytest.param(1, 1, 6, 2, (10, 10), (1, 0, 6, 5), id="constrained left"),
pytest.param(1, 6, 7, 8, (10, 10), (1, 3, 7, 9), id="constrained right"),
pytest.param(1, 1, 2, 6, (10, 10), (0, 1, 5, 6), id="constrained top"),
pytest.param(6, 1, 8, 7, (10, 10), (3, 1, 9, 7), id="constrained bottom"),
pytest.param(117, 20, 521, 603, (608, 608), (24, 20, 607, 603), id="constrained top and bottom"),
],
)
def test_make_bounding_box_square(
Expand All @@ -173,6 +176,14 @@ def test_make_bounding_box_square(
) -> None:
"""Test the make_bounding_box_square method."""
result = make_bounding_box_square(crop_min_row, crop_min_col, crop_max_row, crop_max_col, image_shape)
# check bbox within image bounds
assert result[0] >= 0
assert result[1] >= 0
assert result[2] <= image_shape[0]
assert result[3] <= image_shape[1]
# check if square
assert (result[2] - result[0]) == (result[3] - result[1])
# check bbox coords match
assert result == expected_indices


Expand Down
28 changes: 17 additions & 11 deletions topostats/grains.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from topostats.logs.logs import LOGGER_NAME
from topostats.thresholds import threshold
from topostats.unet_masking import (
iou_loss,
make_bounding_box_square,
mean_iou,
pad_bounding_box,
predict_unet,
)
Expand All @@ -33,6 +35,7 @@
# pylint: disable=too-many-arguments
# pylint: disable=bare-except
# pylint: disable=dangerous-default-value
# pylint: disable=too-many-positional-arguments
# pylint: disable=too-many-lines
# pylint: disable=too-many-public-methods

Expand Down Expand Up @@ -657,11 +660,13 @@ def improve_grain_segmentation_unet(
# I haven't tested it yet.

try:
unet_model = keras.models.load_model(unet_config["model_path"], compile=False)
unet_model = keras.models.load_model(
unet_config["model_path"], custom_objects={"mean_iou": mean_iou, "iou_loss": iou_loss}, compile=False
)
except Exception as e:
LOGGER.info(f"Python executable: {sys.executable}")
LOGGER.info(f"Keras version: {keras.__version__}")
LOGGER.info(f"Model path: {unet_config['model_path']}")
LOGGER.debug(f"Python executable: {sys.executable}")
LOGGER.debug(f"Keras version: {keras.__version__}")
LOGGER.debug(f"Model path: {unet_config['model_path']}")
raise e

# unet_model = keras.models.load_model(unet_config["model_path"], custom_objects={"mean_iou": mean_iou})
Expand Down Expand Up @@ -699,13 +704,14 @@ def improve_grain_segmentation_unet(
)

# Make the bounding box square within the confines of the image
bounding_box = make_bounding_box_square(
crop_min_row=bounding_box[0],
crop_min_col=bounding_box[1],
crop_max_row=bounding_box[2],
crop_max_col=bounding_box[3],
image_shape=(image.shape[0], image.shape[1]),
)
if (bounding_box[2] - bounding_box[0]) != (bounding_box[3] - bounding_box[1]):
bounding_box = make_bounding_box_square(
crop_min_row=bounding_box[0],
crop_min_col=bounding_box[1],
crop_max_row=bounding_box[2],
crop_max_col=bounding_box[3],
image_shape=(image.shape[0], image.shape[1]),
)

# Grab the cropped image. Using slice since the bounding box from skimage is
# half-open, so the max_row and max_col are not included in the region.
Expand Down
53 changes: 25 additions & 28 deletions topostats/unet_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

LOGGER = logging.getLogger(LOGGER_NAME)

# pylint: disable=too-many-positional-arguments
# pylint: disable=too-many-locals


# DICE Loss
ns-rse marked this conversation as resolved.
Show resolved Hide resolved
def dice_loss(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32], smooth: float = 1e-5) -> tf.Tensor:
Expand Down Expand Up @@ -199,7 +202,7 @@ def predict_unet(
# Resize the channel mask to the original image size, but we want boolean so use nearest neighbour
# Sylvia: Pylint incorrectly thinks that Image.NEAREST is not a member of Image. IDK why.
# pylint: disable=no-member
channel_mask_PIL = channel_mask_PIL.resize((original_image.shape[0], original_image.shape[1]), Image.NEAREST)
channel_mask_PIL = channel_mask_PIL.resize((original_image.shape[1], original_image.shape[0]), Image.NEAREST)
resized_predicted_mask[:, :, channel_index] = np.array(channel_mask_PIL).astype(bool)

return resized_predicted_mask
Expand Down Expand Up @@ -245,45 +248,39 @@ def make_bounding_box_square(
if crop_min_col - diff // 2 >= 0 and crop_max_col + diff - diff // 2 < image_shape[1]:
new_crop_min_col = crop_min_col - diff // 2
new_crop_max_col = crop_max_col + diff - diff // 2
# If we can't expand uniformly, expand in just one direction
# Check if we can expand right
elif crop_max_col + diff - diff // 2 < image_shape[1]:
# We can expand right
new_crop_min_col = crop_min_col
new_crop_max_col = crop_max_col + diff
elif crop_min_col - diff // 2 >= 0:
# We can expand left
new_crop_min_col = crop_min_col - diff
new_crop_max_col = crop_max_col
# If we can't expand uniformly, expand as much as possible in that dir
else:
# Crop expansion below 0
if crop_min_col - diff // 2 <= 0:
new_crop_min_col = 0
new_crop_max_col = crop_max_col + (diff - crop_min_col)
# Crop expansion beyond image size
else:
new_crop_max_col = image_shape[1] - 1
new_crop_min_col = crop_min_col - (diff - (image_shape[1] - 1 - crop_max_col))
# Set the new crop height to the original crop height since we are just updating the width
new_crop_min_row = crop_min_row
new_crop_max_row = crop_max_row
elif crop_width > crop_height:
else:
# The crop is wider than it is tall
diff = crop_width - crop_height
# Check if we can expand equally in each direction
if crop_min_row - diff // 2 >= 0 and crop_max_row + diff - diff // 2 < image_shape[0]:
new_crop_min_row = crop_min_row - diff // 2
new_crop_max_row = crop_max_row + diff - diff // 2
# If we can't expand uniformly, expand in just one direction
# Check if we can expand down
elif crop_max_row + diff - diff // 2 < image_shape[0]:
# We can expand down
new_crop_min_row = crop_min_row
new_crop_max_row = crop_max_row + diff
elif crop_min_row - diff // 2 >= 0:
# We can expand up
new_crop_min_row = crop_min_row - diff
new_crop_max_row = crop_max_row
# If we can't expand uniformly, expand as much as possible in that dir
else:
# Crop expansion below 0
if crop_min_row - diff // 2 <= 0:
new_crop_min_row = 0
new_crop_max_row = crop_max_row + (diff - crop_min_row)
# Crop expansion beyond image size
else:
new_crop_max_row = image_shape[0] - 1
new_crop_min_row = crop_min_row - (diff - (image_shape[0] - 1 - crop_max_row))
# Set the new crop width to the original crop width since we are just updating the height
new_crop_min_col = crop_min_col
new_crop_max_col = crop_max_col
else:
# If the crop is already square, return the original crop
new_crop_min_row = crop_min_row
new_crop_min_col = crop_min_col
new_crop_max_row = crop_max_row
new_crop_max_col = crop_max_col

return new_crop_min_row, new_crop_min_col, new_crop_max_row, new_crop_max_col

Expand Down
Loading