Skip to content

Commit

Permalink
additional typing annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
myscon committed Apr 24, 2024
1 parent d7739a7 commit 64414db
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions src/ltgee/landtrendr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ def mask_labels(self):
return self._mask_labels

@mask_labels.setter
def mask_labels(self, mask_labels):
def mask_labels(self, mask_labels: list):
assert all([_ in self._mask_options for _ in mask_labels]
), f"mask_labels must be a subset of {self._mask_options}"
self._mask_labels = mask_labels

def _build_sr_collection(self, debug=False):
def _build_sr_collection(self, debug: Optional[bool] = False):
"""
Builds a medoid composite of Landsat surface reflectance TM-equivalent bands 1,2,3,4,5,7.
This collection can be useful outside of use by LandTrendr, but is also the base for creating the input collection for LandTrendr.
Expand All @@ -78,7 +78,10 @@ def _build_sr_collection(self, debug=False):
return ee.ImageCollection(
[self._build_medoid_mosaic(year, dummy_collection, debug) for year in range(self.start_date.year, self.end_date.year + 1)])

def _build_medoid_mosaic(self, year, dummy_collection, debug=False):
def _build_medoid_mosaic(self,
year: int,
dummy_collection: ee.ImageCollection,
debug: Optional[bool] = False):
collection = self._get_combined_sr_collection(year)
image_count = collection.size()
final_collection = ee.ImageCollection(ee.Algorithms.If(
Expand All @@ -100,14 +103,14 @@ def _build_medoid_mosaic(self, year, dummy_collection, debug=False):
.set('system:time_start', ee.Date.fromYMD(year, self.start_date.month, self.start_date.day).millis())\
.toUint16()

def _get_combined_sr_collection(self, year):
def _get_combined_sr_collection(self, year: int):
lt5 = self._get_sr_collection(year, 'LT05')
le7 = self._get_sr_collection(year, 'LE07')
lc8 = self._get_sr_collection(year, 'LC08')
lc9 = self._get_sr_collection(year, 'LC09')
return lt5.merge(le7).merge(lc8).merge(lc9)

def _get_sr_collection(self, year, sensor):
def _get_sr_collection(self, year: int, sensor: str):
if self.start_date.month > self.end_date.month:
start_date = ee.Date.fromYMD(
year - 1, self.start_date.month, self.start_date.day)
Expand All @@ -125,7 +128,7 @@ def _get_sr_collection(self, year, sensor):
.set("system:time_start", start_date.millis())
return self._remove_images(sr_collection)

def _preprocess_image(self, image, sensor):
def _preprocess_image(self, image: int, sensor: str):
# Accounting for band shift between landsat difference landsat images
if sensor == 'LC08' or sensor == 'LC09':
dat = image.select(['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'],
Expand All @@ -138,10 +141,10 @@ def _preprocess_image(self, image, sensor):
dat = self._apply_masks(image.select('QA_PIXEL'), dat)
return dat

def _scale_unmask_image(self, image):
def _scale_unmask_image(self, image: ee.Image):
return image.multiply(0.0000275).add(-0.2).multiply(10000).toUint16().unmask()

def _apply_masks(self, qa, dat):
def _apply_masks(self, qa: ee.Image, dat: ee.Image):
mask = ee.Image(1)
# TODO: Refactor to allow dynamically allow new masks
for mask_label in self.mask_labels:
Expand All @@ -160,7 +163,7 @@ def _apply_masks(self, qa, dat):
mask = mask.mask(forest_mask(self.area_of_interest))
return dat.mask(mask)

def _remove_images(self, collection):
def _remove_images(self, collection: ee.ImageCollection):
"""
Removes images from a collection based on the given exclude criteria.
Expand Down Expand Up @@ -223,7 +226,7 @@ def index(self):
return self._index

@index.setter
def index(self, index):
def index(self, index: str):
assert index in self._valid_indices or index in self._valid_indices_alt, f"Index must be one of {self._valid_indices} or {self._valid_indices_alt}"
self._index = index

Expand All @@ -232,7 +235,7 @@ def ftv_list(self):
return self._ftv_list

@ftv_list.setter
def ftv_list(self, ftv_list):
def ftv_list(self, ftv_list: list):
assert all([_ in self._valid_indices for _ in ftv_list]
), f"ftv_list must be a subset of {self._valid_indices}"
self._ftv_list = ftv_list
Expand Down Expand Up @@ -595,7 +598,7 @@ def run_params(self):
return self._run_params

@run_params.setter
def run_params(self, run_params):
def run_params(self, run_params: dict):
assert all([_ in self._default_run_params.keys()
for _ in run_params.keys()]), f"run_params must be a subset of {self._default_run_params.keys()}"
if hasattr(self, '_data'):
Expand Down Expand Up @@ -946,7 +949,7 @@ def get_fitted_rgb_col(self,

return rgb_col

def _apply_mmu(self, image, mmu_value):
def _apply_mmu(self, image: ee.Image, mmu_value: int) -> ee.Image:
mmu_image = image.select([0])\
.gte(ee.Number(1))\
.selfMask()\
Expand Down

0 comments on commit 64414db

Please sign in to comment.