From f9108138fd8b7efdf28ff5fddb697d57675b58fb Mon Sep 17 00:00:00 2001 From: Max Gamill Date: Tue, 10 Dec 2024 12:40:52 +0000 Subject: [PATCH 1/4] adds ability to import older topostats unet models --- topostats/grains.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/topostats/grains.py b/topostats/grains.py index 2e2ac4cd61..64fcdfd8f1 100644 --- a/topostats/grains.py +++ b/topostats/grains.py @@ -17,6 +17,7 @@ from topostats.logs.logs import LOGGER_NAME from topostats.thresholds import threshold from topostats.unet_masking import ( + iou_loss, make_bounding_box_square, mean_iou, pad_bounding_box, @@ -611,7 +612,9 @@ def improve_grain_segmentation_unet( # You may also get an error referencing a "group_1" parameter, this is discussed in this issue: # https://github.com/keras-team/keras/issues/19441 which also has an experimental fix that we can try but # I haven't tested it yet. - unet_model = keras.models.load_model(unet_config["model_path"], custom_objects={"mean_iou": mean_iou}) + unet_model = keras.models.load_model( + unet_config["model_path"], custom_objects={"mean_iou": mean_iou, "iou_loss": iou_loss}, compile=False + ) LOGGER.debug(f"Output shape of UNet model: {unet_model.output_shape}") # Initialise an empty mask to iteratively add to for each grain, with the correct number of class channels based on From 24d93dca83923b5e9883e5a6b01d01097bdb28d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:11:54 +0000 Subject: [PATCH 2/4] [pre-commit.ci] Fixing issues with pre-commit --- topostats/grains.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topostats/grains.py b/topostats/grains.py index 362b3ede11..1e3cd5c53b 100644 --- a/topostats/grains.py +++ b/topostats/grains.py @@ -660,7 +660,7 @@ def improve_grain_segmentation_unet( try: unet_model = keras.models.load_model( - unet_config["model_path"], custom_objects={"mean_iou": mean_iou, "iou_loss": iou_loss}, compile=False + unet_config["model_path"], custom_objects={"mean_iou": mean_iou, "iou_loss": iou_loss}, compile=False ) except Exception as e: LOGGER.debug(f"Python executable: {sys.executable}") From 6988cb60168d928463ce4822d9a5847442ce535f Mon Sep 17 00:00:00 2001 From: Max Gamill Date: Tue, 10 Dec 2024 15:20:11 +0000 Subject: [PATCH 3/4] re(?)-add mean_iou func as import --- topostats/grains.py | 1 + 1 file changed, 1 insertion(+) diff --git a/topostats/grains.py b/topostats/grains.py index 1e3cd5c53b..8ed3a42cac 100644 --- a/topostats/grains.py +++ b/topostats/grains.py @@ -21,6 +21,7 @@ from topostats.unet_masking import ( iou_loss, make_bounding_box_square, + mean_iou, pad_bounding_box, predict_unet, ) From ba118973e70bb7de6d6789f53c8a064cce77e4ea Mon Sep 17 00:00:00 2001 From: Neil Shephard Date: Wed, 11 Dec 2024 11:01:26 +0000 Subject: [PATCH 4/4] Update topostats/unet_masking.py --- topostats/unet_masking.py | 1 - 1 file changed, 1 deletion(-) diff --git a/topostats/unet_masking.py b/topostats/unet_masking.py index 79bc55e427..7700fbc36f 100644 --- a/topostats/unet_masking.py +++ b/topostats/unet_masking.py @@ -18,7 +18,6 @@ # pylint: disable=too-many-locals -# DICE Loss def dice_loss(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32], smooth: float = 1e-5) -> tf.Tensor: """ DICE loss function.