Skip to content

Commit

Permalink
Fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgleith committed Nov 20, 2024
1 parent ec117db commit e90c970
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
8 changes: 4 additions & 4 deletions odc/geo/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class LANDSAT_C2L2_PIXEL_QA(Enum):

def bits_to_bool(
xx: DataArray,
bits: Sequence[int] | None,
bitflags: int | None,
bits: Sequence[int] | None = None,
bitflags: int | None = None,
invert: bool = False,
) -> DataArray:
"""
Expand Down Expand Up @@ -121,8 +121,8 @@ def enum_to_bool(

def scale_and_offset(
xx: DataArray | Dataset,
scale: float | None,
offset: float | None,
scale: float | None = None,
offset: float | None = None,
clip: Annotated[Sequence[int | float], 2] | None = None,
) -> DataArray | Dataset:
"""
Expand Down
37 changes: 29 additions & 8 deletions tests/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@
mask_invalid_data,
)

from xarray import DataArray

from xarray import DataArray, Dataset

# Top left is cloud, top right is cloud shadow
# Bottom left is both cloud and cloud shadow, bottom right is neither
xx_bits = DataArray(
[[0b00010000, 0b00001000], [0b00011000, 0b00000000]], dims=("y", "x")
[[0b00010000, 0b00001000], [0b00011000, 0b00000000]], dims=("y", "x"), attrs={"nodata": 0}
)

# Set up a 2x2 8 bit integer DataArray with some
# values set to 3 (shadow), 9 (high confidence cloud).
xx_values = DataArray([[3, 9], [3, 0]], dims=("y", "x"))
xx_values = DataArray([[3, 9], [3, 0]], dims=("y", "x"), attrs={"nodata": 0})

# Array with some zeros
xx_with_nodata = DataArray([[1, 2], [0, 0]], dims=("y", "x"), attrs={"nodata": 0})
xx_with_nodata = DataArray([[0, 1], [2, 3]], dims=("y", "x"), attrs={"nodata": 0})


# Test bits_to_bool
Expand Down Expand Up @@ -54,17 +55,37 @@ def test_scale_and_offset():
mask = scale_and_offset(xx_values, scale=1.0, offset=0.0)
assert mask.equals(DataArray([[3, 9], [3, 0]], dims=("y", "x")))

mask = scale_and_offset(xx_values, scale=None, offset=None, ignore_missing=True)
mask = scale_and_offset(xx_values)
assert mask.equals(DataArray([[3, 9], [3, 0]], dims=("y", "x")))

mask = scale_and_offset(xx_values, scale=2.0, offset=1.0)
assert mask.equals(DataArray([[7, 19], [7, 1]], dims=("y", "x")))
assert mask.equals(DataArray([[7, 19], [7, 0]], dims=("y", "x")))


# Test mask_invalid
def test_mask_invalid_data():
mask = mask_invalid_data(xx_with_nodata)
assert mask.equals(DataArray([[1.0, 2.0], [np.nan, np.nan]], dims=("y", "x")))
assert mask.equals(DataArray([[np.nan, 1.0], [2.0, 3.0]], dims=("y", "x")))

mask = mask_invalid_data(xx_with_nodata, nodata=1)
assert mask.equals(DataArray([[np.nan, 2], [0, 0]], dims=("y", "x")))
assert mask.equals(DataArray([[0, np.nan], [2, 3]], dims=("y", "x")))


# Test landsat masking
def test_mask_landsat():
xx = Dataset({"pixel_qa": xx_bits, "red": scale_and_offset(xx_with_nodata, offset=20000)})
print(xx)

xx = xx.odc.mask_ls()

assert xx["red"].equals(DataArray([[np.nan, np.nan], [np.nan, 0.3500825]], dims=("y", "x")))


def test_mask_sentinel2():
xx = Dataset({"scl": xx_values, "red": scale_and_offset(xx_with_nodata, offset=8000)})

xx = xx.odc.mask_s2()

assert xx["red"].equals(DataArray([[np.nan, np.nan], [np.nan, 0.7003]], dims=("y", "x")))

assert xx.red.odc.nodata is not None

0 comments on commit e90c970

Please sign in to comment.