diff --git a/requirements.txt b/requirements.txt index 5481a8e..7c3de82 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy tensorflow>=2.9.3 -apache-beam>=2.38.0 +apache-beam[gcp]>=2.38.0 earthengine-api python-dotenv>=1.0.0 matplotlib diff --git a/workflow/v2/generate_training_patches.py b/workflow/v2/generate_training_patches.py index bcda3ce..655871f 100644 --- a/workflow/v2/generate_training_patches.py +++ b/workflow/v2/generate_training_patches.py @@ -37,19 +37,27 @@ class TrainingDataGenerator: A class to generate training data for machine learning models. This class utilizes Apache Beam for data processing and Earth Engine for handling geospatial data. """ - def __init__(self, config, include_after: bool = False, split_dir: str = "training", **kwargs): + def __init__(self, config, split_dir: str = "training", **kwargs): """ Constructor for the TrainingDataGenerator class. Parameters: - include_after (bool): If True, includes "after" images in the generated data. Default is False. + config (Config): The configuration object (aces.Config) containing the necessary settings. + split_dir (str): The split directory. Default is training. Choices are training, testing, and validation. + **kwargs: Additional keyword arguments to pass to the class. + Additional keyword arguments: + - test_ratio (float): The test ratio. Default is 0.1. + - validation_ratio (float): The validation ratio. Default is 0.2. + - seed (int): The seed value for random number generation. Default is 100. + - use_service_account (bool): Whether to use the service account for authentication. Default is False. + - label (str or ee.Image): The label dataset to load from. Default is the Bhutan ACES 2 dataset. + - image (str or ee.Image): The image dataset to load from. Required; a ValueError is raised if it is omitted or is not a str or ee.Image. 
""" self.config = config self.output_bucket = self.config.GCS_BUCKET self.kernel_size = self.config.PATCH_SHAPE_SINGLE self.grace = 10 self.scale = self.config.SCALE - self.include_after = include_after self.test_ratio = kwargs.get("test_ratio", 0.1) self.validation_ratio = kwargs.get("validation_ratio", 0.2) self.seed = kwargs.get("seed", 100) @@ -57,18 +65,33 @@ def __init__(self, config, include_after: bool = False, split_dir: str = "traini self.split_dir = split_dir print(f"split_dir: {self.split_dir}") + self.sample_locations = ee.FeatureCollection(kwargs.get("sampled_locations", + "projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped_new")) + self.sample_locations = self.sample_locations.randomColumn("random", self.seed) + + default_label = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_label").rename("class").unmask(0, False) + label = kwargs.get("label", default_label) + if isinstance(label, str): + self.label = ee.Image(label) + elif isinstance(label, ee.Image): + self.label = label + else: + raise ValueError("Invalid label provided. Please provide a valid Earth Engine Image object or path to an Image object.") + + image = kwargs.get("image") + if isinstance(image, str): + self.image = ee.Image(image) + elif isinstance(image, ee.Image): + self.image = image + else: + raise ValueError("Invalid image provided. Please provide a valid Earth Engine Image object or path to an Image object.") + def load_data(self) -> None: """ Load the necessary data from Earth Engine and prepare it for use. 
""" EEUtils.initialize_session(use_highvolume=True, key=self.config.EE_SERVICE_CREDENTIALS if self.use_service_account else None) - self.l1 = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/BT_Admin_1") - self.paro = self.l1.filter(ee.Filter.eq("ADM1_EN", "Paro")) - # self.sample_locations = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples") - # self.sample_locations = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped") - self.sample_locations = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped_new") - self.sample_locations = self.sample_locations.randomColumn("random", self.seed) self.training_sample_locations = self.sample_locations.filter(ee.Filter.gt("random", self.validation_ratio + self.test_ratio)) # > 0.4 print("Training sample size:", self.training_sample_locations.size().getInfo()) self.validation_sample_locations = self.sample_locations.filter(ee.Filter.lte("random", self.validation_ratio)) # <= 0.2 @@ -80,61 +103,17 @@ def load_data(self) -> None: print("Sample size:", self.sample_size) self.sample_locations_list = self.sample_locations.toList(self.sample_size + self.grace) - self.label = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_label").rename("class").unmask(0, False) - # self.other = self.label.remap([0, 1], [1, 0]).rename(["other"]) - - self.composite_after = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_after") - - self.composite_before = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_before") - - self.composite_during = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_during") - - evi_before = EEUtils.calculate_evi(self.composite_before) - evi_during = EEUtils.calculate_evi(self.composite_during) - 
evi_after = EEUtils.calculate_evi(self.composite_after) - - self.composite_before = self.composite_before.addBands(evi_before) - self.composite_during = self.composite_during.addBands(evi_during) - self.composite_after = self.composite_after.addBands(evi_after) - - original_bands = self.composite_before.bandNames().getInfo() - lowercase_bands = [band.lower() for band in original_bands] - - self.composite_before = self.composite_before.select(original_bands, lowercase_bands) - self.composite_after = self.composite_after.select(original_bands, lowercase_bands) - self.composite_during = self.composite_during.select(original_bands, lowercase_bands) - - self.composite_before = self.composite_before.regexpRename("$(.*)", "_before") - self.composite_after = self.composite_after.regexpRename("$(.*)", "_after") - self.composite_during = self.composite_during.regexpRename("$(.*)", "_during") - - self.sentinel1_asc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscBefore") - self.sentinel1_asc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscDuring") - - self.sentinel1_desc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescBefore") - self.sentinel1_desc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescDuring") - - self.elevation = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/elevationParo") - self.slope = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/slopeParo") - - if self.include_after: - self.image = self.composite_before.addBands(self.composite_during).addBands(self.composite_after).toFloat() - else: - self.image = self.composite_before.addBands(self.composite_during).toFloat() - if self.config.USE_S1: - self.image = 
self.image.addBands(self.sentinel1_asc_before_composite).addBands(self.sentinel1_asc_during_composite).addBands(self.sentinel1_desc_before_composite).addBands(self.sentinel1_desc_during_composite).toFloat() self.config.FEATURES.extend(["vv_asc_before", "vh_asc_before", "vv_asc_during", "vh_asc_during", - "vv_desc_before", "vh_desc_before", "vv_desc_during", "vh_desc_during"]) + "vv_desc_before", "vh_desc_before", "vv_desc_during", "vh_desc_during"]) if self.config.USE_ELEVATION: - self.image = self.image.addBands(self.elevation).addBands(self.slope).toFloat() self.config.FEATURES.extend(["elevation", "slope"]) self.image = self.image.select(self.config.FEATURES) - self.image = self.image.addBands(self.label).toFloat() - + self.image = self.image.addBands(self.label).toFloat() print("Image bands:", self.image.bandNames().getInfo()) + self.selectors = self.image.bandNames().getInfo() proj = ee.Projection("EPSG:4326").atScale(10).getInfo() @@ -190,9 +169,9 @@ def _generate_data(image, data, selectors, use_service_account, prefix): if self.split_dir == "training": _generate_data(self.image, self.training_sample_locations, self.selectors, self.use_service_account, training_file_prefix) elif self.split_dir == "validation": - _generate_data(self.image, self.validation_sample_locations, self.selectors, self.use_service_account, testing_file_prefix) + _generate_data(self.image, self.validation_sample_locations, self.selectors, self.use_service_account, validation_file_prefix) elif self.split_dir == "testing": - _generate_data(self.image, self.test_sample_locations, self.selectors, self.use_service_account, validation_file_prefix) + _generate_data(self.image, self.test_sample_locations, self.selectors, self.use_service_account, testing_file_prefix) else: print("Invalid split name specified. Choices are training, validation, and testing. 
Exiting..") exit(1) @@ -344,9 +323,15 @@ def run_neighborhood_generator(self) -> None: use as python generate_training_data.py --mode patch \n or python generate_training_data.py --m point""") - parser.add_argument("-c", "--config", "--c", help="`.env` file to load config from") + parser.add_argument("--config", help="`.env` file to load config from") + + parser.add_argument("--split_directory", help="The split directory. Default is training. Choices are training, testing, and validation \n") - parser.add_argument("-s", "--split_directory", "--s", help="The split directory. Default is training. Choices are training, testing, and validation \n") + parser.add_argument("--sample_data", help="The sampled dataset to load from. Please provide the path only to the ee.FeatureCollection to load from \n") + + parser.add_argument("--label_data", help="The label dataset to load from. Please either provide the path or directly the ee.Image Object to load from \n") + + parser.add_argument("--image_data", help="The image dataset to load from. 
Please either provide the path or directly the ee.Image Object to load from \n") # Read arguments from command line mode = "neighborhood" @@ -355,6 +340,29 @@ def run_neighborhood_generator(self) -> None: else: print("No mode specified, defaulting to neighborhood.") + split_dir = "training" + if parser.parse_args().split_directory: + split_dir = parser.parse_args().split_directory + else: + print("No split directory specified, defaulting to `training`.") + + # sample locations + sample_locations = "projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped_new" + if parser.parse_args().sample_data: + sample_locations = parser.parse_args().sample_data + else: + print("No dataset specified, defaulting to what we have here.") + + EEUtils.initialize_session(use_highvolume=True) + # label dataset + if parser.parse_args().label_data: + label = parser.parse_args().label_data + else: + print("No label dataset specified, defaulting to what we have here.") + label = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_label").rename("class").unmask(0, False) + # other = label.remap([0, 1], [1, 0]).rename(["other"]) + + # config file config = "../config.env" if parser.parse_args().config: config = parser.parse_args().config @@ -363,14 +371,62 @@ def run_neighborhood_generator(self) -> None: config = Config(config) - split_dir = "training" - if parser.parse_args().split_directory: - split_dir = parser.parse_args().split_directory + + # image + composite_after = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_after") + composite_before = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_before") + composite_during = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_during") + + evi_before = EEUtils.calculate_evi(composite_before) + evi_during = EEUtils.calculate_evi(composite_during) + evi_after = 
EEUtils.calculate_evi(composite_after) + + composite_before = composite_before.addBands(evi_before) + composite_during = composite_during.addBands(evi_during) + composite_after = composite_after.addBands(evi_after) + + original_bands = composite_before.bandNames().getInfo() + lowercase_bands = [band.lower() for band in original_bands] + + composite_before = composite_before.select(original_bands, lowercase_bands) + composite_after = composite_after.select(original_bands, lowercase_bands) + composite_during = composite_during.select(original_bands, lowercase_bands) + + composite_before = composite_before.regexpRename("$(.*)", "_before") + composite_after = composite_after.regexpRename("$(.*)", "_after") + composite_during = composite_during.regexpRename("$(.*)", "_during") + + sentinel1_asc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscBefore") + sentinel1_asc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscDuring") + + sentinel1_desc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescBefore") + sentinel1_desc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescDuring") + + elevation = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/elevationParo") + slope = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/slopeParo") + + include_after = False + if include_after: + image = composite_before.addBands(composite_during).addBands(composite_after).toFloat() else: - print("No split directory specified, defaulting to `training`.") + image = composite_before.addBands(composite_during).toFloat() + + if config.USE_S1: + image = image.addBands(sentinel1_asc_before_composite).addBands(sentinel1_asc_during_composite)\ + .addBands(sentinel1_desc_before_composite).addBands(sentinel1_desc_during_composite).toFloat() + + if config.USE_ELEVATION: + image = 
image.addBands(elevation).addBands(slope).toFloat() + + # image dataset + if parser.parse_args().image_data: + image = parser.parse_args().image_data + else: + print("No image dataset specified, defaulting to what we have here.") + image = image # more settings can be applied here - generator = TrainingDataGenerator(config=config, split_dir=split_dir) + generator = TrainingDataGenerator(config=config, split_dir=split_dir, sampled_locations=sample_locations, image=image, label=label) if mode == "patch": generator.run_patch_generator() elif mode == "patch_seed":