Skip to content

Commit

Permalink
Latest updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpf committed Nov 23, 2024
1 parent e7d2298 commit 8b6c18a
Show file tree
Hide file tree
Showing 10 changed files with 412 additions and 95 deletions.
1 change: 0 additions & 1 deletion conda_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
- python
- pytorch::pytorch
Expand Down
48 changes: 24 additions & 24 deletions gprof_nn/config_files/gprof_nn_3d_inference.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tile_size = [128, 64]
tile_size = [128, 128]
spatial_overlap = 16
input_loader = "gprof_nn.retrieval.GPROFNNInputLoader"
input_loader_args = {{config="3d", needs_ancillary={ancillary}}}
Expand All @@ -8,26 +8,26 @@ surface_precip = "ExpectedValue"
surface_precip_terciles = {{retrieval_output="Quantiles", tau=[0.33, 0.66]}}
probability_of_precipitation = {{retrieval_output="ExceedanceProbability", threshold=1e-3}}

[retrieval_output.convective_precip]
convective_precip = "ExpectedValue"

[retrieval_output.rain_water_path]
rain_water_path = "ExpectedValue"

[retrieval_output.ice_water_path]
ice_water_path = "ExpectedValue"

[retrieval_output.cloud_water_path]
cloud_water_path = "ExpectedValue"

[retrieval_output.rain_water_content]
rain_water_content = "ExpectedValue"

[retrieval_output.snow_water_content]
snow_water_content = "ExpectedValue"

[retrieval_output.cloud_water_content]
cloud_water_content = "ExpectedValue"

[retrieval_output.latent_heat]
latent_heating = "ExpectedValue"
#[retrieval_output.convective_precip]
#convective_precip = "ExpectedValue"
#
#[retrieval_output.rain_water_path]
#rain_water_path = "ExpectedValue"
#
#[retrieval_output.ice_water_path]
#ice_water_path = "ExpectedValue"
#
#[retrieval_output.cloud_water_path]
#cloud_water_path = "ExpectedValue"
#
#[retrieval_output.rain_water_content]
#rain_water_content = "ExpectedValue"
#
#[retrieval_output.snow_water_content]
#snow_water_content = "ExpectedValue"
#
#[retrieval_output.cloud_water_content]
#cloud_water_content = "ExpectedValue"
#
#[retrieval_output.latent_heat]
#latent_heating = "ExpectedValue"
31 changes: 20 additions & 11 deletions gprof_nn/data/cloudsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from calendar import monthrange
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from datetime import datetime, timedelta
import logging
from pathlib import Path
from typing import Tuple
Expand All @@ -32,25 +32,21 @@
from pansat.utils import resample_data
from rich.progress import track, Progress

from gprof_nn.statistics import TrainingDataStats
from gprof_nn.sensors import Sensor
from gprof_nn.data.utils import (
run_preprocessor,
upsample_data,
add_cpcir_data,
calculate_obs_properties,
extract_scenes,
mask_invalid_values
mask_invalid_values,
PANSAT_PRODUCTS
)


LOGGER = logging.getLogger(__name__)

# pansat products for each sensor.
PRODUCTS = {
"gmi": (l1c_gpm_gmi,),
"atms": (l1c_npp_atms, l1c_noaa20_atms),
"amsr2": (l1c_gcomw1_amsr2,)
}

UPSAMPLING_FACTORS = {
"gmi": (3, 1),
Expand Down Expand Up @@ -184,6 +180,15 @@ def extract_cloudsat_scenes(
valid_field[~valid] = np.nan
input_data["valid"] = (("scans", "pixels"), valid_field)

pflag = input_data["precip_flag"].data
surface_precip = input_data["surface_precip"].data
surface_precip_snow = input_data["surface_precip_snow"].data
total_precip = np.nan * np.zeros_like(surface_precip)
total_precip[pflag == 0] = 0.0
total_precip[surface_precip > 0] = surface_precip[surface_precip > 0]
total_precip[surface_precip_snow > 0] = surface_precip[surface_precip_snow > 0]
input_data["total_precip"] = (("scans", "pixels"), total_precip)

input_data["input_observations"] = input_obs.observations.rename({"channels": "all_channels"})
input_data["input_meta_data"] = input_obs.meta_data.rename({"channels": "all_channels"})
mask_invalid_values(input_data)
Expand Down Expand Up @@ -226,11 +231,15 @@ def extract_cloudsat_scenes(
"snow_water_path",
"cloud_liquid_water_path",
"surface_precip",
"total_precip",
]:
encodings[var] = {"dtype": "float32", "zlib": True}

stats = TrainingDataStats(output_path)

scene_ind = 0
for scene in scenes:
stats.track(scene, valid_var="total_precip")
scene = scene.drop_vars("valid")
start_time = target_granule.time_range.start
start_str = start_time.strftime("%Y%m%d%H%M%S")
Expand Down Expand Up @@ -258,7 +267,7 @@ def extract_samples(
output_path: The path to which to write the extracted training scenes.
scene_size: The size of the training scenes to extract.
"""
input_products = PRODUCTS[sensor.name.lower()]
input_products = PANSAT_PRODUCTS[sensor.name.lower()]
target_product = l2c_rain_profile
for input_product in input_products:
input_recs = input_product.get(TimeRange(start_time, end_time))
Expand Down Expand Up @@ -328,7 +337,7 @@ def cli(
if n_processes is None:
for day in track(days):
start_time = datetime(year, month, day)
end_time = datetime(year, month, day + 1)
end_time = datetime(year, month, day) + timedelta(days=1)
extract_samples(
sensor,
start_time,
Expand All @@ -341,7 +350,7 @@ def cli(
tasks = []
for day in days:
start_time = datetime(year, month, day)
end_time = datetime(year, month, day + 1)
end_time = datetime(year, month, day) + timedelta(days=1)
tasks.append(
pool.submit(
extract_samples,
Expand Down
2 changes: 1 addition & 1 deletion gprof_nn/data/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ def has_preprocessor():

# Dictionary mapping sensor IDs to preprocessor executables.
PREPROCESSOR_EXECUTABLES = {
"GMI": "gprof2023pp_GMI_L1C",
"GMI": "gprof2024pp_GMI_L1C",
"MHS": "gprof2023pp_MHS_L1C",
"TMIPR": "gprof2021pp_TMI_L1C",
"TMIPO": "gprof2021pp_TMI_L1C",
Expand Down
47 changes: 32 additions & 15 deletions gprof_nn/data/pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from calendar import monthrange
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from datetime import datetime, timedelta
import logging
import multiprocessing
from pathlib import Path
Expand Down Expand Up @@ -107,8 +107,7 @@ def extract_pretraining_scenes(
input_data[var].data[invalid] = np.nan

upsampling_factors = UPSAMPLING_FACTORS[input_sensor.name.lower()]
if max(upsampling_factors) > 1:
input_data = upsample_data(input_data, upsampling_factors)
input_data = upsample_data(input_data, upsampling_factors)
input_data = add_cpcir_data(input_data)

rof_in = RADIUS_OF_INFLUENCE[input_sensor.name.lower()]
Expand Down Expand Up @@ -138,6 +137,12 @@ def extract_pretraining_scenes(
tbs[tbs < 0] = np.nan
valid *= np.isfinite(tbs).any(0)
training_data["valid"] = (("scans", "pixels"), np.zeros_like(valid, dtype="float32"))

scan_time_input = input_obs.scan_time
scan_time_target = input_obs.scan_time
time_diff = scan_time_input - scan_time_target
valid *= np.abs(time_diff.data) < np.timedelta64(15, "m")

training_data.valid.data[~valid] = np.nan

scenes = extract_scenes(
Expand All @@ -163,15 +168,22 @@ def extract_pretraining_scenes(
"two_meter_temperature": {"dtype": "uint16", "zlib": True, "scale_factor": 0.1, "_FillValue": uint16_max},
"total_column_water_vapor": {"dtype": "float32", "zlib": True},
"leaf_area_index": {"dtype": "float32", "zlib": True},
"land_fraction": {"dtype": "int8", "zlib": True},
"ice_fraction": {"dtype": "int8", "zlib": True},
"land_fraction": {"dtype": "int8", "zlib": True, "_FillValue": -1},
"ice_fraction": {"dtype": "int8", "zlib": True, "_FillValue": -1},
"elevation": {"dtype": "uint16", "zlib": True, "scale_factor": 0.5, "_FillValue": uint16_max},
"ir_observations": {"dtype": "uint16", "zlib": True, "scale_factor": 0.01, "_FillValue": uint16_max},
}

for var in training_data:
print(var, training_data[var].dtype)

scene_ind = 0
for scene in scenes:
scene = scene.drop_vars(["valid"])
meta = scene["input_meta_data"].data
meta[meta < 0] = np.nan
meta = scene["target_meta_data"].data
meta[meta < 0] = np.nan
start_time = target_granule.time_range.start
start_str = start_time.strftime("%Y%m%d%H%M%S")
end_time = target_granule.time_range.end
Expand Down Expand Up @@ -261,14 +273,19 @@ def extract_samples(
target_index = Index.index(target_product, target_recs)
matches = find_matches(input_index, target_index, np.timedelta64(15, "m"))
for match in matches:
extract_pretraining_scenes(
input_sensor,
target_sensor,
match,
output_path,
scene_size=scene_size,
)

try:
extract_pretraining_scenes(
input_sensor,
target_sensor,
match,
output_path,
scene_size=scene_size,
)
except Exception:
LOGGER.exception(
"Encountered an error when extracting training data for match %s",
match[0]
)


def process_l1c_file(
Expand Down Expand Up @@ -441,7 +458,7 @@ def cli(
if n_processes is None:
for day in track(days):
start_time = datetime(year, month, day)
end_time = datetime(year, month, day + 1)
end_time = start_time + timedelta(days=1)
extract_samples(
input_sensor,
target_sensor,
Expand All @@ -455,7 +472,7 @@ def cli(
tasks = []
for day in days:
start_time = datetime(year, month, day)
end_time = datetime(year, month, day)
end_time = start_time + timedelta(days=1)
tasks.append(
pool.submit(
extract_samples,
Expand Down
Loading

0 comments on commit 8b6c18a

Please sign in to comment.