diff --git a/analytics/shorelines.ipynb b/analytics/shorelines.ipynb new file mode 100644 index 0000000..03eed4b --- /dev/null +++ b/analytics/shorelines.ipynb @@ -0,0 +1,917 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Cloud-native coastal waterline mapping\n", + "\n", + "NOTE: Since the Planetary Computer Hub was retired this workflow is broken. Hopefully it will be fixed soon. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": { + "tags": [], + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "branch = \"dev\"\n", + "sys.path.insert(0, \"../src\")\n", + "from odc.stac import configure_rio\n", + "\n", + "from coastmonitor.io.drive_config import configure_instance\n", + "\n", + "is_local_instance = configure_instance(branch=branch)\n", + "configure_rio(cloud_defaults=True)\n", + "\n", + "import logging\n", + "import os\n", + "import time\n", + "import warnings\n", + "\n", + "import dask\n", + "\n", + "dask.config.set({\"dataframe.query-planning\": False})\n", + "import dask.array as da\n", + "import dask.dataframe as dd\n", + "import dask_geopandas\n", + "import geopandas as gpd\n", + "import hvplot.pandas\n", + "import hvplot.xarray # noqa # noqa # noqa\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "from astropy.convolution import convolve\n", + "from dask import delayed\n", + "from dask.distributed import performance_report\n", + "from dotenv import load_dotenv\n", + "from geopandas.array import GeometryDtype\n", + "from scipy import ndimage\n", + "from skimage import filters, morphology\n", + "\n", + "from coastmonitor.dask_utils import generate_geometry_mask\n", + "from coastmonitor.geo.quadtiles import make_mercantiles\n", + "from coastmonitor.io.cloud import list_storage_location, write_block, write_table\n", + "from coastmonitor.io.eo import load_sentinel2_data\n", + "from coastmonitor.io.load import (\n", + " infer_region_of_interest,\n", + " retrieve_coastsat_classifier,\n", + " retrieve_rois,\n", + " retrieve_s2_tiles,\n", + ")\n", + "from coastmonitor.io.utils import name_block, name_table\n", + "from coastmonitor.transform.transform import (\n", + " add_indices_preserve_nodata,\n", + " mask_by_classes,\n", + " mask_invalid_values,\n", + ")\n", + "from coastmonitor.xarray_utils import extract_and_set_nodata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": { + "tags": [], + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "### Input params ###\n", + "DATE_RANGE = \"2015-06-23/2024-11-01\"\n", + "ROI = \"NARRABEEN\" # NOTE: or bbox, like: (4.169655, 52.047265, 4.209394, 52.068902)\n", + "CLOUD_COVER = {\"lt\": 10}\n", + "INDICES = [\"NDWI\", \"MNDWI\", \"NDVI\", \"NDMI\", \"BR\"]\n", + "BANDS = [\"blue\", \"green\", \"red\", \"nir\", \"swir16\", \"SCL\"]\n", + "OVERWRITE = True\n", + "\n", + "PROFILE_OPTIONS = {\n", + " \"driver\": \"COG\",\n", + " \"dtype\": \"uint8\",\n", + " \"compress\": \"LZW\",\n", + "}\n", + "MERCANTILES_ZOOM_LEVEL = 10\n", + "\n", + "SCL_CLASSES = {\n", + " 0: \"No Data\",\n", + " 1: \"Saturated / Defective\",\n", + " 2: \"Dark Area Pixels\",\n", + " 3: \"Cloud Shadows\",\n", + " 4: \"Vegetation\",\n", + " 5: \"Bare Soils\",\n", + " 6: \"water\",\n", + " 7: \"Clouds low probability / Unclassified\",\n", + " 8: \"Clouds medium probability\",\n", + " 9: \"Clouds high probability\",\n", + " 
10: \"Cirrus\",\n", + " 11: \"Snow / Ice\",\n", + "}\n", + "\n", + "SCL_CLASSES_TO_MASK = [\n", + " \"No Data\",\n", + " \"Dark Area Pixels\",\n", + " \"Clouds high probability\",\n", + " \"Cirrus\",\n", + " # \"Snow / Ice\", # do not mask because whitewater is often classified as ice\n", + "]\n", + "\n", + "start_date_range, end_date_range = DATE_RANGE.split(\"/\")\n", + "wop_storage_prefix = f\"az://wop/{start_date_range}_to_{end_date_range}\"\n", + "shoreline_storage_prefix = f\"az://shorelines/{start_date_range}_to_{end_date_range}\"\n", + "\n", + "### NOTE: env vars\n", + "load_dotenv(override=True)\n", + "\n", + "sas_token = os.getenv(\"AZURE_STORAGE_SAS_TOKEN\")\n", + "storage_account_name = os.getenv(\"AZURE_STORAGE_ACCOUNT_NAME\")\n", + "storage_options = {\"account_name\": storage_account_name, \"sas_token\": sas_token}\n", + "gh_coastmonitor_token = os.getenv(\"GH_COASTMONITOR_TOKEN\")\n", + "\n", + "# Useful for fetching data on Planetary Computer\n", + "# NOTE: this should now be convered by odc.stac configure rio\n", + "# os.environ[\"GDAL_HTTP_MAX_RETRY\"] = \"3\"\n", + "\n", + "# Adjust logging level for azure and rasterio\n", + "logging.getLogger(\"azure\").setLevel(logging.WARNING)\n", + "logging.getLogger(\"rasterio\").setLevel(logging.WARNING)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": { + "tags": [], + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "def clean_noise(\n", + " image: np.ndarray, structure_size: int = 6, min_object_size: int = 100\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Apply morphological operations to clean noise from a binary image.\n", + "\n", + " Args:\n", + " image (np.ndarray): Input binary array (image).\n", + " structure_size (int, optional): Size of the square structuring element for morphological operations. Defaults to 6.\n", + " min_object_size (int, optional): Minimum size of objects to retain in the image. Defaults to 100.\n", + "\n", + " Returns:\n", + " np.ndarray: The cleaned image.\n", + "\n", + " Example:\n", + " >>> image = np.random.randint(0, 2, (100, 100), dtype=bool)\n", + " >>> cleaned_image = clean_noise_from_image(image, structure_size=5, min_object_size=50)\n", + " \"\"\"\n", + " structure = morphology.square(structure_size)\n", + " binary_opening = ndimage.binary_opening(image, structure=structure)\n", + " image = morphology.remove_small_objects(\n", + " binary_opening, min_size=min_object_size, connectivity=1\n", + " )\n", + " return image\n", + "\n", + "\n", + "def standard_deviation(\n", + " image: np.ndarray, radius: int, nodata: float | int = np.nan\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Calculates the standard deviation of an image using a moving window of\n", + " specified radius with astropy's convolution library.\n", + "\n", + " Args:\n", + " image (np.ndarray): 2D array containing the pixel intensities of a single-band image.\n", + " radius (int): Radius defining the moving window used to calculate the standard deviation.\n", + " For example, radius = 1 will produce a 3x3 moving window.\n", + " nodata (float, optional): Value to replace NaN results with. 
Defaults to np.nan.\n", + "\n", + " Returns:\n", + " np.ndarray: 2D array containing the standard deviation of the image.\n", + "\n", + " Example:\n", + " >>> img = np.random.random((100, 100))\n", + " >>> std_img = standard_deviation(img, radius=1)\n", + " \"\"\"\n", + "\n", + " # Create kernel once\n", + " win_rows, win_cols = radius * 2 + 1, radius * 2 + 1\n", + " kernel = np.ones((win_rows, win_cols))\n", + "\n", + " # Pre-calculate square of image\n", + " image_sq = image**2\n", + "\n", + " # First pad the image and its square\n", + " image_padded = np.pad(image, radius, \"reflect\")\n", + " image_sq_padded = np.pad(image_sq, radius, \"reflect\")\n", + "\n", + " # Calculate std with uniform filters\n", + " win_mean = convolve(\n", + " image_padded,\n", + " kernel,\n", + " boundary=\"extend\",\n", + " normalize_kernel=True,\n", + " nan_treatment=\"interpolate\",\n", + " preserve_nan=True,\n", + " )\n", + " win_sqr_mean = convolve(\n", + " image_sq_padded,\n", + " kernel,\n", + " boundary=\"extend\",\n", + " normalize_kernel=True,\n", + " nan_treatment=\"interpolate\",\n", + " preserve_nan=True,\n", + " )\n", + " win_var = win_sqr_mean - win_mean**2\n", + "\n", + " # Ignore RuntimeWarnings in the sqrt calculation\n", + " with warnings.catch_warnings():\n", + " warnings.filterwarnings(\"ignore\", message=\"invalid value encountered in sqrt\")\n", + " win_std = np.sqrt(win_var)\n", + "\n", + " # Remove padding\n", + " win_std = win_std[radius:-radius, radius:-radius]\n", + "\n", + " # After computing standard deviation, replace NaN values with nodata\n", + " win_std[np.isnan(win_std)] = nodata\n", + "\n", + " return win_std\n", + "\n", + "\n", + "def apply_standard_deviation(\n", + " da, input_core_dims, output_core_dims, radius=1, nodata=np.nan\n", + "):\n", + " \"\"\"\n", + " Apply the standard deviation calculation to an xarray DataArray.\n", + "\n", + " Args:\n", + " data_array (xr.DataArray): Input data.\n", + " radius (int): The radius for the moving window.\n", + " nodata (float): Value to replace NaN results with.\n", + " dim (str or list): Dimensions over which to apply the ufunc.\n", + "\n", + " Returns:\n", + " xr.DataArray: Result of the standard deviation calculation.\n", + " \"\"\"\n", + " return xr.apply_ufunc(\n", + " standard_deviation,\n", + " da,\n", + " input_core_dims=input_core_dims,\n", + " output_core_dims=output_core_dims,\n", + " vectorize=True,\n", + " dask=\"parallelized\",\n", + " kwargs={\"radius\": radius, \"nodata\": nodata},\n", + " output_dtypes=[\"f4\"],\n", + " )\n", + "\n", + "\n", + "def add_stdev_preserve_nodata(\n", + " ds: xr.Dataset,\n", + " exclude_vars: list | None = None,\n", + " nodata: float | int = np.nan,\n", + " radius: int = 1,\n", + ") -> xr.Dataset:\n", + " \"\"\"\n", + " Calculate the standard deviation for all variables in an xarray dataset, excluding specified variables.\n", + "\n", + " Args:\n", + " ds (xr.Dataset): Input dataset.\n", + " exclude_vars (list, optional): List of variable names to exclude from the standard deviation calculation.\n", + " Defaults to None.\n", + " nodata (float, optional): Value to replace NaN results with. If not provided, it will be inferred.\n", + " Defaults to None.\n", + " radius (int, optional): Radius to compute the stdev. 1 equals an kernel of 3x3.\n", + " Defaults to 1.\n", + "\n", + " Returns:\n", + " xr.Dataset: Dataset with original variables and new variables containing standard deviations.\n", + "\n", + " Example:\n", + " >>> ds = xr.Dataset({\n", + " ... 
\"a\": ([\"x\", \"y\"], np.random.rand(4, 3)),\n", + " ... \"b\": ([\"x\", \"y\"], np.random.rand(4, 3))\n", + " ... })\n", + " >>> result_ds = calculate_stdev_with_nodata(ds)\n", + " \"\"\"\n", + "\n", + " if exclude_vars is None:\n", + " exclude_vars = []\n", + "\n", + " ds_for_stdev = ds.drop_vars(exclude_vars)\n", + " rename_dict = {var: var + \"_std\" for var in list(ds_for_stdev.data_vars)}\n", + "\n", + " # If nodata_valu# If nodata_value is not explicitly provided, infer it\n", + " if nodata is None:\n", + " _, nodata = extract_and_set_nodata(ds, list(rename_dict.keys()), [])\n", + "\n", + " ds_stdev = apply_standard_deviation(\n", + " ds_for_stdev,\n", + " input_core_dims=[[\"y\", \"x\"]],\n", + " output_core_dims=[[\"y\", \"x\"]],\n", + " radius=radius,\n", + " nodata=nodata,\n", + " )\n", + "\n", + " # Construct new variable names and rename them\n", + " ds_stdev = ds_stdev.rename(rename_dict)\n", + "\n", + " # Merge original and new data\n", + " ds = xr.merge([ds, ds_stdev])\n", + "\n", + " ds, _ = extract_and_set_nodata(\n", + " ds, list(rename_dict.keys()), list(rename_dict.values())\n", + " )\n", + "\n", + " return ds\n", + "\n", + "\n", + "def classify_image(arr, classifier):\n", + " output_shape = arr.shape[:2]\n", + "\n", + " arr = arr.reshape(-1, arr.shape[-1])\n", + " nan_mask = np.isnan(arr).any(axis=1) # computes nans along features index\n", + "\n", + " # arr with nan values to store result\n", + " result = np.zeros(\n", + " arr.shape[0],\n", + " )\n", + " result[:] = np.nan\n", + "\n", + " # don't bother classifying (and avoid error's) when all values are nan\n", + " if nan_mask.all():\n", + " return result.reshape(output_shape)\n", + "\n", + " predictions = classifier.predict(arr[~nan_mask])\n", + " result[~nan_mask] = predictions\n", + " result = result.reshape(output_shape)\n", + "\n", + " # TODO: erosion/dilation\n", + " ...\n", + " return result\n", + "\n", + "\n", + "def compute_otsu_threshold(arr):\n", + " \"\"\"\n", + " Compute Otsu's threshold for a flattened array.\n", + "\n", + " Args:\n", + " arr (np.ndarray): 1D array.\n", + "\n", + " Returns:\n", + " float: The computed Otsu's threshold value.\n", + " \"\"\"\n", + " nan_mask = np.isnan(arr)\n", + " if nan_mask.all():\n", + " return np.nan\n", + " return filters.threshold_otsu(arr[~nan_mask])\n", + "\n", + "\n", + "def load_s2_shoreline_hypercube(\n", + " bbox: gpd.GeoDataFrame(),\n", + " query: dict,\n", + " date_range: str,\n", + " bands,\n", + " indices,\n", + " scl_classes_to_mask,\n", + " buffer_roi_scattered,\n", + " classifier,\n", + "):\n", + " import rioxarray # noqa\n", + "\n", + " SCL_CLASSES = {\n", + " 0: \"No Data\",\n", + " 1: \"Saturated / Defective\",\n", + " 2: \"Dark Area Pixels\",\n", + " 3: \"Cloud Shadows\",\n", + " 4: \"Vegetation\",\n", + " 5: \"Bare Soils\",\n", + " 6: \"water\",\n", + " 7: \"Clouds low probability / Unclassified\",\n", + " 8: \"Clouds medium probability\",\n", + " 9: \"Clouds high probability\",\n", + " 10: \"Cirrus\",\n", + " 11: \"Snow / Ice\",\n", + " }\n", + "\n", + " if \"SCL\" not in bands:\n", + " bands += \"SCL\"\n", + "\n", + " xx = load_sentinel2_data(bbox, date_range, query, bands)\n", + "\n", + " mask = mask_invalid_values(xx)\n", + " scl_mask = mask_by_classes(xx.SCL, scl_classes_to_mask, SCL_CLASSES)\n", + "\n", + " geom_mask = generate_geometry_mask(xx.red, buffer_roi_scattered)\n", + "\n", + " mask = mask.merge(scl_mask)\n", + " mask = mask.merge(geom_mask)\n", + "\n", + " def apply_clean_noise(da, input_core_dims, output_core_dims, 
output_dtypes):\n", + " return xr.apply_ufunc(\n", + " clean_noise,\n", + " da,\n", + " input_core_dims=input_core_dims,\n", + " output_core_dims=output_core_dims,\n", + " vectorize=True,\n", + " dask=\"parallelized\",\n", + " output_dtypes=output_dtypes,\n", + " )\n", + "\n", + " mask[\"SCL_mask\"] = apply_clean_noise(\n", + " mask.SCL_mask,\n", + " input_core_dims=[[\"y\", \"x\"]],\n", + " output_core_dims=[[\"y\", \"x\"]],\n", + " output_dtypes=mask.SCL_mask.dtype,\n", + " )\n", + "\n", + " xx = add_indices_preserve_nodata(\n", + " xx,\n", + " bands=bands,\n", + " indices=indices,\n", + " bands_to_rename={\"swir16\": \"swir1\"},\n", + " )\n", + "\n", + " # invert NDWI, MNDWI, NDMI to match with CoastSat's classifier\n", + " for var in [\"NDWI\", \"MNDWI\", \"NDMI\"]:\n", + " xx[var] = xx[var] * -1\n", + "\n", + " scl = xx.SCL\n", + "\n", + " xx = xx.drop_vars(\"SCL\")\n", + " xx = add_stdev_preserve_nodata(xx, exclude_vars=[], radius=1, nodata=np.nan)\n", + "\n", + " mask_ = mask.drop_vars([\"SCL_mask\"]).to_array(\"variables\").any(\"variables\")\n", + "\n", + " # create a feature array\n", + " feature_da = (\n", + " xx.where(~mask_)\n", + " .to_array(\"features\")\n", + " .chunk({\"features\": 20})\n", + " .transpose(\"time\", \"y\", \"x\", \"features\")\n", + " )\n", + "\n", + " def apply_classify_image(da, input_core_dims, output_core_dims):\n", + " return xr.apply_ufunc(\n", + " classify_image,\n", + " da,\n", + " input_core_dims=input_core_dims,\n", + " output_core_dims=output_core_dims,\n", + " vectorize=True,\n", + " dask=\"parallelized\",\n", + " output_dtypes=\"f4\",\n", + " kwargs={\"classifier\": classifier},\n", + " )\n", + "\n", + " xx[\"CLASS\"] = apply_classify_image(\n", + " feature_da,\n", + " input_core_dims=[[\"y\", \"x\", \"features\"]],\n", + " output_core_dims=[[\"y\", \"x\"]],\n", + " ).rename(\"CLASS\")\n", + "\n", + " mask = mask.merge(~xx[\"CLASS\"].isin([1, 3]).rename(\"CLASS_mask\"))\n", + "\n", + " # # keep everything together\n", + " xx = xx.merge(mask)\n", + "\n", + " # # computing otsu threshold masking out all potential noise sources\n", + " otsu_threshold_mask = (\n", + " mask.to_array(\"variable\").any(\"variable\").rename(\"otsu_threshold_mask\")\n", + " )\n", + "\n", + " shoreline_mask = (\n", + " mask.drop_vars([\"CLASS_mask\", \"SCL_mask\"])\n", + " .to_array(\"variable\")\n", + " .any(\"variable\")\n", + " .rename(\"shoreline_mask\")\n", + " )\n", + "\n", + " xx = xx.merge(otsu_threshold_mask)\n", + " xx = xx.merge(shoreline_mask)\n", + "\n", + " def apply_otsu_threshold(da, input_core_dims):\n", + " \"\"\"\n", + " Compute the global Otsu's threshold for an xarray DataArray along a specified dimension.\n", + "\n", + " Args:\n", + " data_array (xr.DataArray): Input xarray DataArray.\n", + " dim (str): Dimension along which to compute the Otsu's threshold.\n", + "\n", + " Returns:\n", + " xr.DataArray: DataArray containing the global Otsu's threshold.\n", + " \"\"\"\n", + "\n", + " return xr.apply_ufunc(\n", + " compute_otsu_threshold,\n", + " da,\n", + " input_core_dims=input_core_dims,\n", + " vectorize=True,\n", + " dask=\"parallelized\",\n", + " output_dtypes=\"f4\",\n", + " )\n", + "\n", + " t_otsu = apply_otsu_threshold(\n", + " xx.MNDWI.where(~otsu_threshold_mask).chunk({\"time\": 1}),\n", + " input_core_dims=[[\"y\", \"x\"]],\n", + " ).rename(\"otsu\")\n", + "\n", + " # Compute the percentage of water pixels in scl and coassat to filter the images\n", + " xx = xr.merge([xx, scl])\n", + "\n", + " coastsat_water_occurrence = (\n", + " 
(xx.CLASS == 3).sum(dim=[\"y\", \"x\"]) / xx.CLASS.count(dim=[\"y\", \"x\"]) * 100\n", + " )\n", + " valid_data_mask = (xx.SCL != 0) & ~xx.SCL.isnull()\n", + " scl_water_occurrence = (\n", + " ((xx.SCL == 6) & valid_data_mask).sum(dim=[\"y\", \"x\"])\n", + " / valid_data_mask.sum(dim=[\"y\", \"x\"])\n", + " * 100\n", + " )\n", + " xx = xx.assign_coords(\n", + " coastsat_water_occurrence=(\"time\", coastsat_water_occurrence.data)\n", + " )\n", + " xx = xx.assign_coords(scl_water_occurrence=(\"time\", scl_water_occurrence.data))\n", + " xx = xx.assign_coords(t_otsu=(\"time\", t_otsu.data))\n", + "\n", + " # NOTE: keep as ref because maybe change to otsu quality checking instead\n", + " # mask_imgs_by_otsu_threshold = (xx.coords[\"t_otsu\"] > -0.5) & (xx.coords[\"t_otsu\"] < 0.5)\n", + " mask_imgs_by_class = (\n", + " (xx.coords[\"coastsat_water_occurrence\"] > 95)\n", + " | (xx.coords[\"scl_water_occurrence\"] > 95)\n", + " ).rename(\"img_mask\")\n", + " xx[\"final_mask\"] = xx.shoreline_mask | mask_imgs_by_class\n", + "\n", + " xx[\"otsu\"] = t_otsu > xx.where(~xx.final_mask).MNDWI\n", + " return xx" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": { + "tags": [], + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "# NOTE: make a Dask Gateway or a LocalCluster depending on the instance type\n", + "if is_local_instance:\n", + " from dask.distributed import Client\n", + "\n", + " logging.info(\"Launching local client...\")\n", + " client = Client(\n", + " threads_per_worker=1,\n", + " processes=True,\n", + " local_directory=\"/tmp\",\n", + " )\n", + "\n", + " def silence_warnings():\n", + " import warnings\n", + "\n", + " warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n", + "\n", + " client.run(silence_warnings)\n", + "\n", + "else:\n", + " import dask_gateway\n", + " from distributed import PipInstall\n", + "\n", + " logging.info(\"Launching dask gateway client...\")\n", + " # NOTE: leave these params as they can be used if more memory is required\n", + " # gateway = dask_gateway.Gateway()\n", + " # cluster_options = gateway.cluster_options()\n", + " # cluster_options[\"worker_memory\"] = 16\n", + " # cluster = gateway.new_cluster(cluster_options)\n", + "\n", + " cluster = dask_gateway.GatewayCluster()\n", + " client = cluster.get_client()\n", + " cluster.adapt(minimum=2, maximum=50)\n", + " plugin = PipInstall(\n", + " [\n", + " f\"git+https://{gh_coastmonitor_token}@github.com/floriscalkoen/coastmonitor.git@{branch}\"\n", + " ]\n", + " )\n", + " client.register_plugin(plugin)\n", + " logging.info(f\"Dashboard can be accessed at: {client.dashboard_link}.\")\n", + "client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": { + "tags": [], + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "# NOTE: LOAD DATA\n", + "s2_tiles = retrieve_s2_tiles().to_crs(4326)\n", + "rois = retrieve_rois().to_crs(4326)\n", + "# TODO: write STAC catalog for the coastal buffer\n", + "buffer = dask_geopandas.read_parquet(\n", + " \"az://coastline-buffer/osm-coastlines-buffer-2000m.parquet\",\n", + " storage_options=storage_options,\n", + ").compute()\n", + "quadtiles = make_mercantiles(zoom_level=MERCANTILES_ZOOM_LEVEL).to_crs(4326)\n", + "\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " classifier = retrieve_coastsat_classifier()\n", + "\n", + "region_of_interest = infer_region_of_interest(ROI)\n", + "\n", + "# NOTE: 
make an overlay with a coastline buffer to avoid querying data we do not need\n",
+ "buffer_aoi = gpd.overlay(buffer, region_of_interest[[\"geometry\"]].to_crs(buffer.crs))\n",
+ "\n",
+ "# TODO: add heuristic to decide which s2 tiles to use\n",
+ "s2_tilenames_to_process = gpd.sjoin(\n",
+ "    s2_tiles, buffer_aoi.to_crs(s2_tiles.crs)\n",
+ ").Name.unique()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "s2_tilenames_to_process"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7",
+ "metadata": {
+ "tags": [],
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# NOTE: process the shorelines per Sentinel-2 tile\n",
+ "dask_report_fp = f\"{ROI}_dask-report.html\"\n",
+ "with performance_report(dask_report_fp):\n",
+ "    start_time = time.time()\n",
+ "    for s2_tilename in s2_tilenames_to_process:\n",
+ "        s2_tile = (\n",
+ "            s2_tiles.loc[s2_tiles.Name == s2_tilename][[\"geometry\"]]\n",
+ "            .explode(index_parts=False)\n",
+ "            .iloc[[0]]\n",
+ "        )\n",
+ "        bbox = tuple(s2_tile.total_bounds)\n",
+ "        query = {\"eo:cloud_cover\": CLOUD_COVER, \"s2:mgrs_tile\": {\"eq\": s2_tilename}}\n",
+ "\n",
+ "        # NOTE: the result will be reprojected to this raster. # TODO: consider if it\n",
+ "        # would be better to do a trim_outer_nans here.\n",
+ "        # BUG: raster cannot be made with one band, so we add green.\n",
+ "        template_raster = load_sentinel2_data(\n",
+ "            bbox, \"2023-06-23/2023-11-01\", query, [\"red\"]\n",
+ "        )\n",
+ "        template_raster = template_raster.red.isel(time=0)\n",
+ "\n",
+ "        s2_tile_by_aoi = gpd.overlay(\n",
+ "            s2_tile, region_of_interest[[\"geometry\"]].to_crs(s2_tile.crs)\n",
+ "        )\n",
+ "\n",
+ "        # NOTE: the S2 tiles are too large to process as a whole, so smaller quadkey\n",
+ "        # tiles are created to process the data in smaller chunks\n",
+ "        tiles = gpd.sjoin(quadtiles, s2_tile_by_aoi).drop(columns=[\"index_right\"])\n",
+ "        tiles = (\n",
+ "            gpd.sjoin(tiles, buffer.to_crs(4326))\n",
+ "            .drop(columns=[\"index_right\", \"EPSG\"])\n",
+ "            .drop_duplicates(\"quadkey\")\n",
+ "        )\n",
+ "        # TODO: replace prints with tqdm\n",
+ "        logging.info(\n",
+ "            f\"Start processing next S2 tile: {s2_tilename} that spans\"\n",
+ "            f\" {len(tiles)} quadtiles.\"\n",
+ "        )\n",
+ "\n",
+ "        # NOTE: process per quadkey\n",
+ "        for _, tile in tiles.iloc[[0]].iterrows():\n",
+ "            # NOTE: don't process data that is already processed.\n",
+ "            if not OVERWRITE:\n",
+ "                list_files_prefix = (\n",
+ "                    f\"{start_date_range}_to_{end_date_range}*s2={s2_tilename}*\"\n",
+ "                )\n",
+ "                files = list_storage_location(\n",
+ "                    f\"az://wop/{start_date_range}_to_{end_date_range}/\",\n",
+ "                    storage_options=storage_options,\n",
+ "                    prefix=list_files_prefix,\n",
+ "                )\n",
+ "                is_processed = any(\n",
+ "                    tile.quadkey in f and s2_tilename in f for f in files\n",
+ "                )\n",
+ "                if is_processed:\n",
+ "                    continue\n",
+ "\n",
+ "            logging.info(f\"Start processing next quadtile: {tile.quadkey}\")\n",
+ "\n",
+ "            roi = tile.to_frame().transpose().set_geometry(\"geometry\", crs=tiles.crs)\n",
+ "            roi = gpd.overlay(roi, region_of_interest[[\"geometry\"]].to_crs(roi.crs))\n",
+ "            roi = rois.loc[[\"NARRABEEN\"]]\n",
+ "\n",
+ "            # NOTE: distribute the buffer for the region of interest to the workers\n",
+ "            buffer_roi = gpd.overlay(buffer, roi.to_crs(buffer.crs)).dissolve()\n",
+ "            # NOTE: in rare cases there is no overlap between the tile and the buffer (at the ROI boundaries)\n",
+ "            if 
buffer_roi.empty:\n", + " continue\n", + " buffer_roi_scattered = client.scatter(buffer_roi, broadcast=True)\n", + "\n", + " # NOTE: query datacube from STAC\n", + " bbox = tuple(roi.total_bounds)\n", + "\n", + " xx = load_s2_shoreline_hypercube(\n", + " bbox=bbox,\n", + " query=query,\n", + " date_range=DATE_RANGE,\n", + " bands=BANDS,\n", + " indices=INDICES,\n", + " scl_classes_to_mask=SCL_CLASSES_TO_MASK,\n", + " buffer_roi_scattered=buffer_roi_scattered,\n", + " classifier=classifier,\n", + " )\n", + "\n", + " xx = xx[[\"otsu\", \"MNDWI\", \"final_mask\", \"shoreline_mask\"]]\n", + "\n", + " def map_shorelines(da):\n", + " # TODO: check how to make rioxarray available to workers by default\n", + " import rioxarray # noqa\n", + "\n", + " from coastmonitor.dea_tools import subpixel_contours\n", + "\n", + " # NOTE: ruff wants to change '== false' to 'is False', which breaks the condition\n", + " ds = da.to_dataset(\"band\")\n", + " df = subpixel_contours(\n", + " ds.MNDWI.where(ds.final_mask == False).to_numpy(), # noqa: E712\n", + " z_values=ds.coords[\"t_otsu\"].to_numpy(),\n", + " crs=ds.rio.crs.to_epsg(),\n", + " affine=ds.rio.transform(),\n", + " )\n", + " df[\"time\"] = ds.coords[\"time\"].item()\n", + " return df\n", + "\n", + " META = gpd.GeoDataFrame(\n", + " {\n", + " \"z_value\": pd.Series(dtype=str),\n", + " \"geometry\": pd.Series(dtype=GeometryDtype),\n", + " \"time\": pd.Series(dtype=\"i8\"),\n", + " }\n", + " )\n", + "\n", + " dfs = []\n", + " for da in (\n", + " xx[[\"MNDWI\", \"final_mask\"]]\n", + " .to_array(\"band\")\n", + " .transpose(\"time\", \"band\", \"y\", \"x\")\n", + " ):\n", + " df = delayed(map_shorelines)(da)\n", + " dfs.append(df)\n", + " shorelines = dd.from_delayed(dfs, meta=META)\n", + "\n", + " # TODO: discuss if we need to save multiple time ranges\n", + " wop = (xx.otsu.where(~xx.final_mask)).mean(\"time\").rename(\"wop\")\n", + " nodata_value = 255\n", + " wop = wop * 100\n", + " wop = wop.where(~np.isnan(wop), nodata_value)\n", + "\n", + " wop = (\n", + " wop.astype(np.uint8)\n", + " .rio.write_nodata(nodata_value)\n", + " .rio.set_spatial_dims(x_dim=\"x\", y_dim=\"y\")\n", + " )\n", + " wop.attrs = xx.attrs\n", + "\n", + " wop, shorelines = dask.compute(*[wop, shorelines])\n", + " wop = wop.rio.reproject_match(template_raster, nodata=nodata_value)\n", + " # NOTE: the outer nan's can be trimmed, but to create tiles with consistent\n", + " # shape they are currenlty included when writing the results to the cloud container.\n", + " # trimmed_wop = trim_outer_nans(wop,nodata=nodata_value).astype(np.uint8)\n", + " wop_href = name_block(\n", + " wop,\n", + " storage_prefix=wop_storage_prefix,\n", + " name_prefix=f\"s2={s2_tilename}_qk={tile.quadkey}\",\n", + " )\n", + "\n", + " write_block(\n", + " wop,\n", + " wop_href,\n", + " storage_options=storage_options,\n", + " profile_options=PROFILE_OPTIONS,\n", + " )\n", + "\n", + " shorelines = shorelines.to_crs(4326)\n", + " shorelines_href = name_table(\n", + " shorelines,\n", + " storage_prefix=shoreline_storage_prefix,\n", + " name_prefix=f\"s2={s2_tilename}_qk={tile.quadkey}\",\n", + " )\n", + " write_table(shorelines, shorelines_href, storage_options=storage_options)\n", + "\n", + " logging.info(\"Done!\")\n", + " elapsed_time = time.time() - start_time\n", + " logging.info(\n", + " \"Time (H:M:S):\"\n", + " f\" {time.strftime('%H:%M:%S', time.gmtime(elapsed_time))}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": { + "tags": [] + }, + "outputs": 
[], + "source": [ + "wop.rio.reproject(4326, nodata=255).where(lambda xx: xx != 255).hvplot(\n", + " x=\"x\", y=\"y\", rasterize=True, geo=True, tiles=\"EsriImagery\", width=600\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "shorelines = shorelines.assign(\n", + " length=shorelines.to_crs(shorelines.estimate_utm_crs()).geometry.length\n", + ")\n", + "shorelines = shorelines.assign(\n", + " time=pd.DatetimeIndex(shorelines.time).strftime(\"%Y-%m-%d\")\n", + ").astype({\"z_value\": \"f4\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "shorelines_ = shorelines.loc[(shorelines.z_value > -0.3) & (shorelines.z_value < 0.3)]\n", + "shorelines_.sort_values(\"length\", ascending=False).iloc[:50].to_crs(4326).explore(\n", + " column=\"time\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:coastal-full] *", + "language": "python", + "name": "conda-env-coastal-full-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}