Revise after feedback from kirill. Add mask_clouds, mask_ls and mask_s2 functions
alexgleith committed Nov 20, 2024
1 parent 479dc32 commit ec117db
Showing 3 changed files with 276 additions and 25 deletions.
17 changes: 15 additions & 2 deletions odc/geo/_xr_interop.py
@@ -51,8 +51,11 @@
from .masking import (
bits_to_bool,
enum_to_bool,
mask_invalid_data,
mask_clouds,
mask_ls,
mask_s2,
scale_and_offset,
scale_and_offset_dataset,
)
from .overlap import compute_output_geobox
from .roi import roi_is_empty
@@ -1065,6 +1068,8 @@ def nodata(self, value: Nodata):

enum_to_bool = _wrap_op(enum_to_bool)

mask_invalid_data = _wrap_op(mask_invalid_data)

if have.rasterio:
write_cog = _wrap_op(write_cog)
to_cog = _wrap_op(to_cog)
@@ -1105,7 +1110,15 @@ def to_rgba(
) -> xarray.DataArray:
return to_rgba(self._xx, bands=bands, vmin=vmin, vmax=vmax)

scale_and_offset = _wrap_op(scale_and_offset_dataset)
scale_and_offset = _wrap_op(scale_and_offset)

mask_invalid_data = _wrap_op(mask_invalid_data)

mask_clouds = _wrap_op(mask_clouds)

mask_ls = _wrap_op(mask_ls)

mask_s2 = _wrap_op(mask_s2)


ODCExtensionDs.to_rgba.__doc__ = to_rgba.__doc__
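
Illustrative only, not part of this diff: assuming _wrap_op exposes these functions as methods on the .odc accessor in the usual way, the new helpers would be called roughly as below; ds is a hypothetical Sentinel-2 L2A Dataset containing an "scl" band.

# Hypothetical accessor usage after this change
masked = ds.odc.mask_s2(return_mask=True)   # cloud-mask a Sentinel-2 L2A load
cleaned = ds.odc.mask_invalid_data()        # convert nodata values to NaN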
264 changes: 242 additions & 22 deletions odc/geo/masking.py
@@ -6,11 +6,67 @@
Functions around supporting cloud masking.
"""

from typing import Annotated, Any, Callable, Sequence
import numpy as np
from xarray import DataArray, Dataset

from enum import Enum


class SENTINEL2_L2A_SCL(Enum):
"""
Sentinel-2 Scene Classification Layer (SCL) values.
"""

NO_DATA = 0
SATURATED_OR_DEFECTIVE = 1
DARK_AREA_PIXELS = 2
CLOUD_SHADOWS = 3
VEGETATION = 4
NOT_VEGETATED = 5
WATER = 6
UNCLASSIFIED = 7
CLOUD_MEDIUM_PROBABILITY = 8
CLOUD_HIGH_PROBABILITY = 9
THIN_CIRRUS = 10
SNOW = 11


SENTINEL2_L2A_SCALE = 0.0001
SENTINEL2_L2A_OFFSET = -0.1


class LANDSAT_C2L2_PIXEL_QA(Enum):
"""
Landsat Collection 2 Surface Reflectance Pixel Quality values.
"""

NO_DATA = 0
DILATED_CLOUD = 1
CIRRUS = 2
CLOUD = 3
CLOUD_SHADOW = 4
SNOW = 5
CLEAR = 6
WATER = 7
# Not sure how to implement these yet...
# CLOUD_CONFIDENCE = [8, 9]
# CLOUD_SHADOW_CONFIDENCE = [10, 11]
# SNOW_ICE_CONFIDENCE = [12, 13]
# CIRRUS_CONFIDENCE = [14, 15]


LANDSAT_C2L2_SCALE = 0.0000275
LANDSAT_C2L2_OFFSET = -0.2

# TODO: QA_RADSAT and QA_AEROSOL for Landsat Collection 2 Surface Reflectance


def bits_to_bool(
xx: DataArray, bits: list[int] | None, bitflags: int | None, invert: bool = False
xx: DataArray,
bits: Sequence[int] | None,
bitflags: int | None,
invert: bool = False,
) -> DataArray:
"""
Convert integer array into boolean array using bitmasks.
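
Illustrative only, not part of this diff: with the Sequence-based signature above, a Landsat cloud/shadow mask could be built roughly like this. ds is a hypothetical Dataset with a "pixel_qa" band, and the bit positions come from LANDSAT_C2L2_PIXEL_QA defined earlier; bitflags is passed explicitly since no default is shown in the diff.

cloud_or_shadow = bits_to_bool(
    ds["pixel_qa"],
    bits=[LANDSAT_C2L2_PIXEL_QA.CLOUD.value, LANDSAT_C2L2_PIXEL_QA.CLOUD_SHADOW.value],
    bitflags=None,
)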
@@ -43,7 +99,9 @@ def bits_to_bool(
return mask


def enum_to_bool(xx: DataArray, values: list, invert: bool = False) -> DataArray:
def enum_to_bool(
xx: DataArray, values: Sequence[Any], invert: bool = False
) -> DataArray:
"""
Convert array into boolean array using a list of invalid values.
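
Illustrative only, not part of this diff: the Sequence-typed values argument is how mask_s2 further down builds its cloud mask from the SCL band; ds here is a hypothetical Sentinel-2 L2A Dataset.

cloudy = enum_to_bool(
    ds["scl"],
    values=[
        SENTINEL2_L2A_SCL.CLOUD_SHADOWS.value,
        SENTINEL2_L2A_SCL.CLOUD_MEDIUM_PROBABILITY.value,
        SENTINEL2_L2A_SCL.CLOUD_HIGH_PROBABILITY.value,
    ],
)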
@@ -62,11 +120,11 @@ def enum_to_bool(xx: DataArray, values: list, invert: bool = False) -> DataArray


def scale_and_offset(
xx: DataArray,
xx: DataArray | Dataset,
scale: float | None,
offset: float | None,
ignore_missing: bool = False,
) -> DataArray:
clip: Annotated[Sequence[int | float], 2] | None = None,
) -> DataArray | Dataset:
"""
Apply scale and offset to the DataArray. Leave scale and offset blank to use
the values from the DataArray's attrs.
@@ -77,7 +135,14 @@
:return: DataArray with scaled and offset values
"""

# Scales and offsets is used by GDAL.
# For the Dataset case, we do this recursively for all variables.
if type(xx) is Dataset:
for var in xx.data_vars:
xx[var] = scale_and_offset(xx[var], scale, offset, clip=clip)

return xx

# "Scales" and "offsets" is used by GDAL.
if scale is None:
scale = xx.attrs.get("scales")

@@ -91,31 +156,186 @@
if offset is None and scale is not None:
offset = 0.0

# Store the nodata values to apply to the result
nodata = xx.odc.nodata

# Stash the attributes
attrs = {k: v for k, v in xx.attrs.items()}

if nodata is not None:
nodata_mask = xx == nodata

# If both are missing, we can just return the original array.
if scale is not None and offset is not None:
xx = xx * scale + offset
else:
if not ignore_missing:
raise ValueError(
"Scale and offset not provided and not found in attrs.scales and attrs.offset"
)
xx = (xx * scale) + offset

if clip is not None:
assert len(clip) == 2, "Clip must be a list of two values"
xx = xx.clip(clip[0], clip[1])

# Re-attach nodata
if nodata is not None:
xx = xx.where(~nodata_mask, other=nodata)

xx.attrs = attrs # Not sure if this is required

return xx


def mask_invalid_data(
xx: DataArray | Dataset,
nodata: int | float | None = None,
skip_bands: Sequence[str] = [],
) -> DataArray | Dataset:
"""
Mask out invalid (nodata) values, setting them to np.nan.
:param xx: DataArray or Dataset
:param nodata: Nodata value to mask; if not provided, it is read from the data's nodata metadata
:param skip_bands: Dataset variables to leave unmasked
:return: DataArray or Dataset with invalid data values converted to np.nan. Note this will change the dtype to float.
"""
if type(xx) is Dataset:
for var in xx.data_vars:
if var not in skip_bands:
xx[var] = mask_invalid_data(xx[var], nodata)
return xx

if nodata is None:
nodata = xx.odc.nodata

assert nodata is not None, "Nodata value must be provided or available in attrs"

xx = xx.where(xx != nodata)
xx.odc.nodata = np.nan

return xx


def scale_and_offset_dataset(
xx: Dataset, scale: float | None, offset: float | None
def mask_clouds(
xx: Dataset,
qa_name: str,
scale: float,
offset: float,
clip: tuple,
mask_func: Callable = enum_to_bool,  # Function used to build the boolean mask (bits_to_bool or enum_to_bool)
mask_func_args: dict = {},  # Keyword arguments passed to mask_func
apply_mask: bool = True,
keep_qa: bool = False,
return_mask: bool = False,
) -> Dataset:
"""
Apply scale and offset to the Dataset. Leave scale and offset blank to use
the values from each DataArray's attrs.
General cloud masking function for both Landsat and Sentinel-2 products.
:param xx: Dataset with integer values
:param scale: Scale factor
:param offset: Offset
:return: Dataset with scaled and offset values
:param xx: Dataset with integer values
:param qa_name: QA band to use for masking
:param scale: Scale value for the dataset
:param offset: Offset value for the dataset
:param clip: Clip range for the scaled data
:param mask_func: Function used to build the boolean mask from the QA band (either bits_to_bool or enum_to_bool)
:param mask_func_args: Keyword arguments passed to mask_func (e.g., bits or values)
:param apply_mask: Apply the cloud mask to the data, erasing data where clouds are present
:param keep_qa: Keep the QA band in the output
:param return_mask: Return the mask as a variable called "mask"
:return: Dataset with invalid and masked values converted to np.nan. Note this will change the dtype to float.
"""
attrs = {k: v for k, v in xx.attrs.items()}

# Retrieve the QA band
try:
qa = xx[qa_name]
except KeyError:
raise KeyError(f"QA band '{qa_name}' not found in dataset.")

# Drop the QA band and apply other preprocessing steps
xx = xx.drop_vars(qa_name)
xx = mask_invalid_data(xx)
xx = scale_and_offset(xx, scale=scale, offset=offset, clip=clip)

# Generate the mask
mask = mask_func(qa, **mask_func_args)

# Apply the mask if required
if apply_mask:
xx = xx.where(~mask)

# Set 'nodata' to np.nan for all variables
for var in xx.data_vars:
xx[var] = scale_and_offset(xx[var], scale, offset, ignore_missing=True)
xx[var].odc.nodata = np.nan

return xx
# Optionally keep the QA band
if keep_qa:
xx[qa_name] = qa

# Optionally return the mask
if return_mask:
xx["mask"] = mask

xx.attrs = attrs

return xx # type: ignore


def mask_ls(
xx: Dataset,
qa_name: str = "pixel_qa",
include_cirrus: bool = False,
apply_mask: bool = True,
keep_qa: bool = False,
return_mask: bool = False,
) -> Dataset:
"""
Perform cloud masking for Landsat Collection 2 products.
"""
mask_bits = [
LANDSAT_C2L2_PIXEL_QA.CLOUD.value,
LANDSAT_C2L2_PIXEL_QA.CLOUD_SHADOW.value,
]
if include_cirrus:
mask_bits.append(LANDSAT_C2L2_PIXEL_QA.CIRRUS.value)

return mask_clouds(
xx=xx,
qa_name=qa_name,
scale=LANDSAT_C2L2_SCALE,
offset=LANDSAT_C2L2_OFFSET,
clip=(0.0, 1.0),
mask_func=bits_to_bool,
mask_func_args={"bits": mask_bits},
apply_mask=apply_mask,
keep_qa=keep_qa,
return_mask=return_mask,
)


def mask_s2(
xx: Dataset,
qa_name: str = "scl",
include_cirrus: bool = False,
apply_mask: bool = True,
keep_qa: bool = False,
return_mask: bool = False,
) -> Dataset:
"""
Perform cloud masking for Sentinel-2 L2A products.
"""
mask_values = [
SENTINEL2_L2A_SCL.SATURATED_OR_DEFECTIVE.value,
SENTINEL2_L2A_SCL.CLOUD_MEDIUM_PROBABILITY.value,
SENTINEL2_L2A_SCL.CLOUD_HIGH_PROBABILITY.value,
SENTINEL2_L2A_SCL.CLOUD_SHADOWS.value,
]
if include_cirrus:
mask_values.append(SENTINEL2_L2A_SCL.THIN_CIRRUS.value)

return mask_clouds(
xx=xx,
qa_name=qa_name,
scale=SENTINEL2_L2A_SCALE,
offset=SENTINEL2_L2A_OFFSET,
mask_func=enum_to_bool,
mask_func_args={"values": mask_values},
clip=(0.0, 1.0),
apply_mask=apply_mask,
keep_qa=keep_qa,
return_mask=return_mask,
)
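
Illustrative only, not part of this diff: a usage sketch of the two wrappers, with ls_ds and s2_ds as hypothetical Datasets loaded with a "pixel_qa" and an "scl" band respectively. The scale/offset constants turn digital numbers into reflectance, e.g. a Landsat DN of 30000 becomes 30000 * 0.0000275 - 0.2 = 0.625 and a Sentinel-2 DN of 5000 becomes 5000 * 0.0001 - 0.1 = 0.4, with results clipped to [0, 1].

from odc.geo.masking import mask_ls, mask_s2

# Landsat Collection 2: mask cloud, cloud shadow and (optionally) cirrus,
# scale to reflectance, and return the boolean mask as a "mask" variable.
ls_masked = mask_ls(ls_ds, include_cirrus=True, return_mask=True)

# Sentinel-2 L2A: mask the SCL cloud/shadow classes and keep the SCL band.
s2_masked = mask_s2(s2_ds, keep_qa=True)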
20 changes: 19 additions & 1 deletion tests/test_masking.py
@@ -1,4 +1,10 @@
from odc.geo.masking import bits_to_bool, enum_to_bool, scale_and_offset
import numpy as np
from odc.geo.masking import (
bits_to_bool,
enum_to_bool,
scale_and_offset,
mask_invalid_data,
)

from xarray import DataArray

@@ -12,6 +18,9 @@
# values set to 3 (shadow), 9 (high confidence cloud).
xx_values = DataArray([[3, 9], [3, 0]], dims=("y", "x"))

# Array with some zeros
xx_with_nodata = DataArray([[1, 2], [0, 0]], dims=("y", "x"), attrs={"nodata": 0})


# Test bits_to_bool
def test_bits_to_bool():
@@ -50,3 +59,12 @@ def test_scale_and_offset():

mask = scale_and_offset(xx_values, scale=2.0, offset=1.0)
assert mask.equals(DataArray([[7, 19], [7, 1]], dims=("y", "x")))


# Test mask_invalid
def test_mask_invalid_data():
mask = mask_invalid_data(xx_with_nodata)
assert mask.equals(DataArray([[1.0, 2.0], [np.nan, np.nan]], dims=("y", "x")))

mask = mask_invalid_data(xx_with_nodata, nodata=1)
assert mask.equals(DataArray([[np.nan, 2], [0, 0]], dims=("y", "x")))
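
A possible follow-up test for the new clip argument (illustrative, not part of this diff); it mirrors test_scale_and_offset, with the scaled value 19 capped at the upper clip bound:

def test_scale_and_offset_clip():
    result = scale_and_offset(xx_values, scale=2.0, offset=1.0, clip=(0.0, 10.0))
    assert result.equals(DataArray([[7.0, 10.0], [7.0, 1.0]], dims=("y", "x")))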
