Skip to content

Commit

Permalink
fix training data generation script
Browse files Browse the repository at this point in the history
fix the apache-beam installation
  • Loading branch information
biplovbhandari committed Apr 22, 2024
1 parent efca0d8 commit e8f500e
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 66 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy
tensorflow>=2.9.3
apache-beam>=2.38.0
apache-beam[gcp]>=2.38.0
earthengine-api
python-dotenv>=1.0.0
matplotlib
186 changes: 121 additions & 65 deletions workflow/v2/generate_training_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,38 +37,61 @@ class TrainingDataGenerator:
A class to generate training data for machine learning models.
This class utilizes Apache Beam for data processing and Earth Engine for handling geospatial data.
"""
def __init__(self, config, split_dir: str = "training", **kwargs):
    """
    Constructor for the TrainingDataGenerator class.

    Parameters:
    config (Config): The configuration object (aces.Config) containing the necessary settings.
        GCS_BUCKET, PATCH_SHAPE_SINGLE, SCALE and USE_SERVICE_ACCOUNT are read here.
    split_dir (str): The split directory. Default is training. Choices are training, testing, and validation.
    **kwargs: Additional keyword arguments to pass to the class.

    Additional keyword arguments:
    - test_ratio (float): The test ratio. Default is 0.1.
    - validation_ratio (float): The validation ratio. Default is 0.2.
    - seed (int): The seed value for random number generation. Default is 100.
    - sample_locations (str): Path to the ee.FeatureCollection of sample locations.
      The legacy spelling `sampled_locations` is also accepted. Default is the
      Paro 2021 clipped sample asset.
    - label (str or ee.Image): The label dataset to load from. Default is the Bhutan ACES 2 dataset.
    - image (str or ee.Image): The image dataset to load from. Required; there is no default.

    Raises:
    ValueError: If `label` or `image` is neither a string path nor an ee.Image.
    """
    self.config = config
    self.output_bucket = self.config.GCS_BUCKET
    self.kernel_size = self.config.PATCH_SHAPE_SINGLE
    self.grace = 10
    self.scale = self.config.SCALE
    self.test_ratio = kwargs.get("test_ratio", 0.1)
    self.validation_ratio = kwargs.get("validation_ratio", 0.2)
    self.seed = kwargs.get("seed", 100)
    # Authentication mode comes from the config, not from kwargs.
    self.use_service_account = self.config.USE_SERVICE_ACCOUNT
    self.split_dir = split_dir
    print(f"split_dir: {self.split_dir}")

    # The caller passes `sample_locations`; the original code only looked up
    # `sampled_locations`, so a user-supplied collection was silently ignored.
    # Accept both spellings, preferring the one the caller actually uses.
    default_sample_locations = "projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped_new"
    sample_locations = kwargs.get("sample_locations")
    if sample_locations is None:
        sample_locations = kwargs.get("sampled_locations")
    if sample_locations is None:
        sample_locations = default_sample_locations
    self.sample_locations = ee.FeatureCollection(sample_locations)
    # Attach a seeded "random" column to each sample feature.
    self.sample_locations = self.sample_locations.randomColumn("random", self.seed)

    default_label = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_label").rename("class").unmask(0, False)
    label = kwargs.get("label", default_label)
    if isinstance(label, str):
        self.label = ee.Image(label)
    elif isinstance(label, ee.Image):
        self.label = label
    else:
        raise ValueError("Invalid label provided. Please provide a valid Earth Engine Image object or path to an Image object.")

    # `image` has no fallback: a missing value falls through to the ValueError.
    image = kwargs.get("image")
    if isinstance(image, str):
        self.image = ee.Image(image)
    elif isinstance(image, ee.Image):
        self.image = image
    else:
        raise ValueError("Invalid image provided. Please provide a valid Earth Engine Image object or path to an Image object.")

def load_data(self) -> None:
"""
Load the necessary data from Earth Engine and prepare it for use.
"""
EEUtils.initialize_session(use_highvolume=True, key=self.config.EE_SERVICE_CREDENTIALS if self.use_service_account else None)
self.l1 = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/BT_Admin_1")
self.paro = self.l1.filter(ee.Filter.eq("ADM1_EN", "Paro"))

# self.sample_locations = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples")
# self.sample_locations = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped")
self.sample_locations = ee.FeatureCollection("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped_new")
self.sample_locations = self.sample_locations.randomColumn("random", self.seed)
self.training_sample_locations = self.sample_locations.filter(ee.Filter.gt("random", self.validation_ratio + self.test_ratio)) # > 0.4
print("Training sample size:", self.training_sample_locations.size().getInfo())
self.validation_sample_locations = self.sample_locations.filter(ee.Filter.lte("random", self.validation_ratio)) # <= 0.2
Expand All @@ -80,61 +103,17 @@ def load_data(self) -> None:
print("Sample size:", self.sample_size)
self.sample_locations_list = self.sample_locations.toList(self.sample_size + self.grace)

self.label = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_label").rename("class").unmask(0, False)
# self.other = self.label.remap([0, 1], [1, 0]).rename(["other"])

self.composite_after = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_after")

self.composite_before = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_before")

self.composite_during = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_during")

evi_before = EEUtils.calculate_evi(self.composite_before)
evi_during = EEUtils.calculate_evi(self.composite_during)
evi_after = EEUtils.calculate_evi(self.composite_after)

self.composite_before = self.composite_before.addBands(evi_before)
self.composite_during = self.composite_during.addBands(evi_during)
self.composite_after = self.composite_after.addBands(evi_after)

original_bands = self.composite_before.bandNames().getInfo()
lowercase_bands = [band.lower() for band in original_bands]

self.composite_before = self.composite_before.select(original_bands, lowercase_bands)
self.composite_after = self.composite_after.select(original_bands, lowercase_bands)
self.composite_during = self.composite_during.select(original_bands, lowercase_bands)

self.composite_before = self.composite_before.regexpRename("$(.*)", "_before")
self.composite_after = self.composite_after.regexpRename("$(.*)", "_after")
self.composite_during = self.composite_during.regexpRename("$(.*)", "_during")

self.sentinel1_asc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscBefore")
self.sentinel1_asc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscDuring")

self.sentinel1_desc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescBefore")
self.sentinel1_desc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescDuring")

self.elevation = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/elevationParo")
self.slope = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/slopeParo")

if self.include_after:
self.image = self.composite_before.addBands(self.composite_during).addBands(self.composite_after).toFloat()
else:
self.image = self.composite_before.addBands(self.composite_during).toFloat()

if self.config.USE_S1:
self.image = self.image.addBands(self.sentinel1_asc_before_composite).addBands(self.sentinel1_asc_during_composite).addBands(self.sentinel1_desc_before_composite).addBands(self.sentinel1_desc_during_composite).toFloat()
self.config.FEATURES.extend(["vv_asc_before", "vh_asc_before", "vv_asc_during", "vh_asc_during",
"vv_desc_before", "vh_desc_before", "vv_desc_during", "vh_desc_during"])
"vv_desc_before", "vh_desc_before", "vv_desc_during", "vh_desc_during"])

if self.config.USE_ELEVATION:
self.image = self.image.addBands(self.elevation).addBands(self.slope).toFloat()
self.config.FEATURES.extend(["elevation", "slope"])

self.image = self.image.select(self.config.FEATURES)
self.image = self.image.addBands(self.label).toFloat()

self.image = self.image.addBands(label).toFloat()
print("Image bands:", self.image.bandNames().getInfo())

self.selectors = self.image.bandNames().getInfo()

proj = ee.Projection("EPSG:4326").atScale(10).getInfo()
Expand Down Expand Up @@ -190,9 +169,9 @@ def _generate_data(image, data, selectors, use_service_account, prefix):
if self.split_dir == "training":
_generate_data(self.image, self.training_sample_locations, self.selectors, self.use_service_account, training_file_prefix)
elif self.split_dir == "validation":
_generate_data(self.image, self.validation_sample_locations, self.selectors, self.use_service_account, testing_file_prefix)
_generate_data(self.image, self.validation_sample_locations, self.selectors, self.use_service_account, validation_file_prefix)
elif self.split_dir == "testing":
_generate_data(self.image, self.test_sample_locations, self.selectors, self.use_service_account, validation_file_prefix)
_generate_data(self.image, self.test_sample_locations, self.selectors, self.use_service_account, testing_file_prefix)
else:
print("Invalid split name specified. Choices are training, validation, and testing. Exiting..")
exit(1)
Expand Down Expand Up @@ -344,9 +323,15 @@ def run_neighborhood_generator(self) -> None:
use as python generate_training_data.py --mode patch \n
or python generate_training_data.py --m point""")

parser.add_argument("-c", "--config", "--c", help="`.env` file to load config from")
parser.add_argument("--config", help="`.env` file to load config from")

parser.add_argument("--split_directory", help="The split directory. Default is training. Choices are training, testing, and validation \n")

parser.add_argument("-s", "--split_directory", "--s", help="The split directory. Default is training. Choices are training, testing, and validation \n")
parser.add_argument("--sample_data", help="The sampled dataset to load from. Please provide the path only to the ee.FeatureCollection to load from \n")

parser.add_argument("--label_data", help="The label dataset to load from. Please either provide the path or directly the ee.Image Object to load from \n")

parser.add_argument("--image_data", help="The image dataset to load from. Please either provide the path or directly the ee.Image Object to load from \n")

# Read arguments from command line
mode = "neighborhood"
Expand All @@ -355,6 +340,29 @@ def run_neighborhood_generator(self) -> None:
else:
print("No mode specified, defaulting to neighborhood.")

split_dir = "training"
if parser.parse_args().split_directory:
split_dir = parser.parse_args().split_directory
else:
print("No split directory specified, defaulting to `training`.")

# sample locations
sample_locations = "projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_samples_clipped_new"
if parser.parse_args().sample_data:
sample_locations = parser.parse_args().sample_data
else:
print("No dataset specified, defaulting to what we have here.")

EEUtils.initialize_session(use_highvolume=True)
# label dataset
if parser.parse_args().label_data:
label = parser.parse_args().label_data
else:
print("No label dataset specified, defaulting to what we have here.")
label = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/paro_2021_all_class_label").rename("class").unmask(0, False)
# other = label.remap([0, 1], [1, 0]).rename(["other"])

# config file
config = "../config.env"
if parser.parse_args().config:
config = parser.parse_args().config
Expand All @@ -363,14 +371,62 @@ def run_neighborhood_generator(self) -> None:

config = Config(config)

split_dir = "training"
if parser.parse_args().split_directory:
split_dir = parser.parse_args().split_directory

# image
composite_after = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_after")
composite_before = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_before")
composite_during = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/Paro_Rice_Composite_2021/composite_during")

evi_before = EEUtils.calculate_evi(composite_before)
evi_during = EEUtils.calculate_evi(composite_during)
evi_after = EEUtils.calculate_evi(composite_after)

composite_before = composite_before.addBands(evi_before)
composite_during = composite_during.addBands(evi_during)
composite_after = composite_after.addBands(evi_after)

original_bands = composite_before.bandNames().getInfo()
lowercase_bands = [band.lower() for band in original_bands]

composite_before = composite_before.select(original_bands, lowercase_bands)
composite_after = composite_after.select(original_bands, lowercase_bands)
composite_during = composite_during.select(original_bands, lowercase_bands)

composite_before = composite_before.regexpRename("$(.*)", "_before")
composite_after = composite_after.regexpRename("$(.*)", "_after")
composite_during = composite_during.regexpRename("$(.*)", "_during")

sentinel1_asc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscBefore")
sentinel1_asc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscDuring")

sentinel1_desc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescBefore")
sentinel1_desc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescDuring")

elevation = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/elevationParo")
slope = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/slopeParo")

include_after = False
if include_after:
image = composite_before.addBands(composite_during).addBands(composite_after).toFloat()
else:
print("No split directory specified, defaulting to `training`.")
image = composite_before.addBands(composite_during).toFloat()

if config.USE_S1:
image = image.addBands(sentinel1_asc_before_composite).addBands(sentinel1_asc_during_composite)\
.addBands(sentinel1_desc_before_composite).addBands(sentinel1_desc_during_composite).toFloat()

if config.USE_ELEVATION:
image = image.addBands(elevation).addBands(slope).toFloat()

# image dataset
if parser.parse_args().image_data:
image = parser.parse_args().image_data
else:
print("No image dataset specified, defaulting to what we have here.")
image = image

# more settings can be applied here
generator = TrainingDataGenerator(config=config, split_dir=split_dir)
generator = TrainingDataGenerator(config=config, split_dir=split_dir, sample_locations=sample_locations, image=image, label=label)
if mode == "patch":
generator.run_patch_generator()
elif mode == "patch_seed":
Expand Down

0 comments on commit e8f500e

Please sign in to comment.