diff --git a/topostats/grains.py b/topostats/grains.py index 9b7f6798da..8ed3a42cac 100644 --- a/topostats/grains.py +++ b/topostats/grains.py @@ -19,7 +19,9 @@ from topostats.logs.logs import LOGGER_NAME from topostats.thresholds import threshold from topostats.unet_masking import ( + iou_loss, make_bounding_box_square, + mean_iou, pad_bounding_box, predict_unet, ) @@ -658,11 +660,13 @@ def improve_grain_segmentation_unet( # I haven't tested it yet. try: - unet_model = keras.models.load_model(unet_config["model_path"], compile=False) + unet_model = keras.models.load_model( + unet_config["model_path"], custom_objects={"mean_iou": mean_iou, "iou_loss": iou_loss}, compile=False + ) except Exception as e: - LOGGER.info(f"Python executable: {sys.executable}") - LOGGER.info(f"Keras version: {keras.__version__}") - LOGGER.info(f"Model path: {unet_config['model_path']}") + LOGGER.debug(f"Python executable: {sys.executable}") + LOGGER.debug(f"Keras version: {keras.__version__}") + LOGGER.debug(f"Model path: {unet_config['model_path']}") raise e # unet_model = keras.models.load_model(unet_config["model_path"], custom_objects={"mean_iou": mean_iou}) diff --git a/topostats/unet_masking.py b/topostats/unet_masking.py index 79bc55e427..7700fbc36f 100644 --- a/topostats/unet_masking.py +++ b/topostats/unet_masking.py @@ -18,7 +18,6 @@ # pylint: disable=too-many-locals -# DICE Loss def dice_loss(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32], smooth: float = 1e-5) -> tf.Tensor: """ DICE loss function.