diff --git a/gprof_nn/data/cloudsat.py b/gprof_nn/data/cloudsat.py
index f2a0aaf..958d92b 100644
--- a/gprof_nn/data/cloudsat.py
+++ b/gprof_nn/data/cloudsat.py
@@ -41,6 +41,7 @@
     calculate_obs_properties,
     extract_scenes,
     mask_invalid_values,
+    RADIUS_OF_INFLUENCE,
    PANSAT_PRODUCTS
 )
 
@@ -66,6 +67,7 @@ def extract_cloudsat_scenes(
     match: Tuple[Granule, Tuple[Granule]],
     output_path: Path,
     scene_size: Tuple[int, int],
+    high_res: bool = False
 ) -> None:
     """
     Extract training scenes between a GPM sensor and CloudSat observations.
@@ -76,6 +78,7 @@
             retrievals.
         output_path: The path to which to write the extracted training scenes.
         scene_size: The size of the training scenes to extract.
+        high_res: Whether to upsample data to ~5 km resolution.
     """
     input_granule, target_granules = match
     target_granules = merge_granules(sorted(list(target_granules)))
@@ -87,9 +90,10 @@
         invalid = input_data[var].data < -1_000
         input_data[var].data[invalid] = np.nan
 
-    upsampling_factors = UPSAMPLING_FACTORS[sensor.name.lower()]
-    if max(upsampling_factors) > 1:
-        input_data = upsample_data(input_data, upsampling_factors)
+    if high_res:
+        upsampling_factors = UPSAMPLING_FACTORS[sensor.name.lower()]
+        if max(upsampling_factors) > 1:
+            input_data = upsample_data(input_data, upsampling_factors)
     input_data = add_cpcir_data(input_data)
 
     lons = input_data.longitude.data
@@ -138,7 +142,14 @@
     cs_data["levels"] = (("levels",), levels)
     cs_data = cs_data.drop_dims("bins")
 
-    cs_data_r = resample_data(cs_data.transpose("rays", "levels"), swath, new_dims=(("scans", "pixels")))
+    rof_in = RADIUS_OF_INFLUENCE[sensor.name.lower()]
+    cs_data_r = resample_data(
+        cs_data.transpose("rays", "levels"),
+        swath,
+        new_dims=("scans", "pixels"),
+        unique=True,
+        radius_of_influence=rof_in / 3 if high_res else rof_in
+    )
 
     input_data["surface_precip"] = (
         ("scans", "pixels"),
@@ -195,10 +205,10 @@
 
     scenes = extract_scenes(
         input_data,
-        n_scans=128,
-        n_pixels=128,
+        n_scans=scene_size[0],
+        n_pixels=scene_size[1],
         overlapping=True,
-        min_valid=100,
+        min_valid=scene_size[0] // 2,
         reference_var="valid",
         offset=50
     )
@@ -256,6 +266,7 @@ def extract_samples(
     end_time: np.datetime64,
     output_path: Path,
     scene_size: Tuple[int, int] = (64, 64),
+    high_res: bool = False
 ) -> None:
     """
     Extract GPM-CloudSat training scenes.
@@ -266,6 +277,7 @@
         end_time: The end of the time period for which to extract training data.
         output_path: The path to which to write the extracted training scenes.
         scene_size: The size of the training scenes to extract.
+        high_res: Whether to extract samples at high resolution.
     """
     input_products = PANSAT_PRODUCTS[sensor.name.lower()]
     target_product = l2c_rain_profile
@@ -282,6 +294,7 @@
                 match,
                 output_path,
                 scene_size=scene_size,
+                high_res=high_res
             )
         except Exception:
             LOGGER.exception(
@@ -295,7 +308,8 @@
 @click.argument("days", nargs=-1, type=int)
 @click.argument("output_path")
 @click.option("--n_processes", default=None, type=int)
-@click.option("--scene_size", type=tuple, default=(64, 64))
+@click.option("--scene_size", type=str, default="64,64")
+@click.option("--high_res", type=bool, default=False)
 def cli(
     sensor: Sensor,
     year: int,
@@ -304,9 +318,10 @@
     output_path: Path,
     n_processes: int,
     scene_size: Tuple[int, int] = (64, 64),
+    high_res: bool = False
 ) -> None:
     """
-    Extract CloudSat scenes data for SATFORMER training.
+    Extract CloudSat scene data for GPROF-NN and GPROF-NN HR training.
 
     Args:
         sensor: The name of the GPM sensor.
@@ -315,6 +330,7 @@
         days: A list of the days of the month for which to extract the training data.
         output_path: The path to which to write the training data.
         n_processes: The number of processes to use for parallel processing.
+        high_res: Whether to extract samples at high resolution.
     """
     from gprof_nn import sensors
 
@@ -334,6 +350,11 @@
         LOGGER.error("The 'output' argument must point to a directory.")
         return 1
 
+    scene_size = tuple(map(int, scene_size.split(",")))
+    if len(scene_size) == 1:
+        scene_size = (scene_size[0],) * 2
+    LOGGER.info("Extracting scenes with scene size %s.", scene_size)
+
     if n_processes is None:
         for day in track(days):
             start_time = datetime(year, month, day)
@@ -344,6 +365,7 @@
                 end_time,
                 output_path=output_path,
                 scene_size=scene_size,
+                high_res=high_res
             )
     else:
         pool = ProcessPoolExecutor(max_workers=n_processes)
@@ -359,6 +381,7 @@
                     end_time,
                     output_path=output_path,
                     scene_size=scene_size,
+                    high_res=high_res
                 )
             )
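Note on the reworked `--scene_size` option: it now arrives as a comma-separated string and is expanded in `cli`, with a single value yielding a square scene. A minimal standalone sketch of the same parsing logic (the helper name `parse_scene_size` is hypothetical):

    def parse_scene_size(arg: str) -> tuple:
        # "64,64" -> (64, 64); "128" -> (128, 128)
        size = tuple(int(val) for val in arg.split(","))
        if len(size) == 1:
            size = (size[0],) * 2
        return size

    assert parse_scene_size("64,64") == (64, 64)
    assert parse_scene_size("128") == (128, 128)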
diff --git a/gprof_nn/data/preprocessor.py b/gprof_nn/data/preprocessor.py
index cc063d9..f69ccee 100644
--- a/gprof_nn/data/preprocessor.py
+++ b/gprof_nn/data/preprocessor.py
@@ -717,7 +717,7 @@ def has_preprocessor():
 
 # Dictionary mapping sensor IDs to preprocessor executables.
 PREPROCESSOR_EXECUTABLES = {
-    "GMI": "gprof2024pp_GMI_L1C",
+    "GMI": "gprof2023pp_GMI_L1C",
     "MHS": "gprof2023pp_MHS_L1C",
     "TMIPR": "gprof2021pp_TMI_L1C",
     "TMIPO": "gprof2021pp_TMI_L1C",
@@ -734,7 +734,7 @@
     ("GMI", "SSMIS"): "gprof2021pp_GMI_SSMIS_L1C",
     ("GMI", "AMSR2"): "gprof2023pp_GMI_L1C",
     ("GMI", "AMSRE"): "gprof2021pp_GMI_AMSRE_L1C",
-    ("GMI", "ATMS"): "gprof2021pp_GMI_ATMS_L1C",
+    ("GMI", "ATMS"): "gprof2023pp_GMI_L1C",
     ("GMI", "TMS"): "gprof2023pp_GMI_L1C",
 }
 
diff --git a/gprof_nn/data/pretraining.py b/gprof_nn/data/pretraining.py
index 8c6ad09..1af55a6 100644
--- a/gprof_nn/data/pretraining.py
+++ b/gprof_nn/data/pretraining.py
@@ -44,7 +44,8 @@
     upsample_data,
     add_cpcir_data,
     calculate_obs_properties,
-    mask_invalid_values
+    mask_invalid_values,
+    RADIUS_OF_INFLUENCE
 )
 from gprof_nn.data.l1c import L1CFile
 from gprof_nn.logging import (
@@ -70,11 +71,6 @@
     "atms": (3, 3,),
     "amsr2": (1, 1)
 }
-RADIUS_OF_INFLUENCE = {
-    "gmi": 20e3,
-    "atms": 100e3,
-    "amsr2": 10e3
-}
 
 
 def extract_pretraining_scenes(
@@ -206,40 +202,98 @@ def __getitem__(self, index: int) -> int:
 
     def load_data(self, ind: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]:
-        input_granule, target_granules = self.inputs[ind]
+        inpt_granule, target_granules = self.inputs[ind]
         target_granule = sorted(list(target_granules))[0]
 
-        input_data = run_preprocessor(input_granule)
-        input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=self.radius_of_influence)
-        target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=self.radius_of_influence)
+        inpt_file = L1CFile(inpt_granule.file_record.local_path)
+        inpt_sensor = inpt_file.sensor
+        targ_file = L1CFile(target_granule.file_record.local_path)
+        targ_sensor = targ_file.sensor
+
+        inpt_data = run_preprocessor(inpt_granule)
+        upsampling_factors = UPSAMPLING_FACTORS[inpt_sensor.name.lower()]
+        inpt_data = upsample_data(inpt_data, upsampling_factors)
+        inpt_data = add_cpcir_data(inpt_data)
+        roi_inpt = RADIUS_OF_INFLUENCE[inpt_sensor.name.lower()]
+        inpt_observations = calculate_obs_properties(
+            inpt_data,
+            inpt_granule,
+            radius_of_influence=roi_inpt
+        )
+        roi_targ = RADIUS_OF_INFLUENCE[targ_sensor.name.lower()]
+        target_observations = calculate_obs_properties(
+            inpt_data,
+            target_granule,
+            radius_of_influence=roi_targ
+        )
+
+        # Mask of pixels with valid geolocation.
+        valid = np.isfinite(inpt_data.longitude.data)
+
+        inpt_obs = inpt_observations.observations.data
+        inpt_meta = inpt_observations.meta_data.data
+
+        obs_in = []
+        meta_in = []
+        for chan_ind, obs in enumerate(inpt_obs):
+            obs[..., ~valid] = np.nan
+            valid_obs = np.isfinite(obs)
+            mean = np.mean(obs[valid_obs])
+            std = np.std(obs[valid_obs])
+            obs_n = (obs - mean) / std
+            obs = np.stack([
+                np.ones_like(obs_n) * mean,
+                np.ones_like(obs_n) * std,
+                obs_n
+            ])
+            obs_in.append(torch.tensor(obs))
+            meta = inpt_meta[chan_ind]
+            meta[..., ~valid_obs] = np.nan
+            meta_in.append(torch.tensor(meta))
+
+        obs_in = torch.stack(obs_in, 1)[None]
+        meta_in = torch.stack(meta_in, 1)[None]
+        obs_in_mask = torch.isnan(obs_in).all(1).all(-1).all(-1)
+
+        inpt = {
+            "observations": obs_in,
+            "input_observation_props": meta_in,
+            "input_observation_mask": obs_in_mask,
+        }
+
+        anc_vars = [
+            "two_meter_temperature",
+            "total_column_water_vapor",
+            "leaf_area_index",
+            "land_fraction",
+            "ice_fraction",
+            "elevation",
+            "ir_observations",
+        ]
+        for anc_var in anc_vars:
+            anc_data = torch.tensor(inpt_data[anc_var].data).to(dtype=torch.float32)
+            if anc_data.dim() < 3:
+                anc_data = anc_data[None]
+            anc_data = anc_data[None, :, None]
+            anc_mask = torch.isnan(anc_data).all()[None, None]
+            inpt[anc_var] = anc_data
+            inpt[anc_var + "_mask"] = anc_mask
+
+        props = torch.tensor(target_observations["meta_data"].data)[None]
+        inpt["output_observation_props"] = props.transpose(1, 2)
 
         training_data = xr.Dataset({
-            "latitude": input_data.latitude,
-            "longitude": input_data.longitude,
-            "input_observations": input_obs.observations.rename(channels="input_channels"),
-            "input_meta_data": input_obs.meta_data.rename(channels="input_channels"),
-            "target_observations": target_obs.observations.rename(channels="target_channels"),
-            "target_meta_data": target_obs.meta_data.rename(channels="target_channels"),
+            "latitude": inpt_data.latitude,
+            "longitude": inpt_data.longitude,
+            "input_observations": inpt_observations.observations.rename(channels="input_channels"),
+            "input_meta_data": inpt_observations.meta_data.rename(channels="input_channels"),
+            "target_observations": target_observations.observations.rename(channels="target_channels"),
+            "target_meta_data": target_observations.meta_data.rename(channels="target_channels"),
         })
 
-        tbs = training_data.input_observations.data
-        tbs[tbs < 0] = np.nan
-        n_seq_in = tbs.shape[0]
-        mask = np.all(np.isnan(tbs), axis=(1, 2))
-        tbs = training_data.target_observations.data
-        tbs[tbs < 0] = np.nan
-        n_seq_out = tbs.shape[0]
-
-        input_data = {
-            "observations": torch.tensor(training_data.input_observations.data)[None, None],
-            "input_observation_mask": torch.tensor(mask, dtype=torch.bool)[None],
-            "input_observation_props": torch.tensor(training_data.input_meta_data.data)[None].transpose(1, 2),
-            "dropped_observation_props": torch.tensor(training_data.input_meta_data.data)[11:][None].transpose(1, 2),
-            "output_observation_props": torch.tensor(training_data.target_meta_data.data)[None].transpose(1, 2),
-        }
         filename = "match_" + target_granule.time_range.start.strftime("%Y%m%d%H%M%S") + ".nc"
-        return input_data, filename, training_data
+        return inpt, filename, training_data
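For context, `load_data` now encodes each input channel as three stacked planes: a constant plane holding the channel mean, a constant plane holding the channel standard deviation, and the normalized observations themselves. A minimal sketch of that encoding with synthetic data (array shape and values are illustrative):

    import numpy as np

    obs = np.random.normal(270.0, 15.0, size=(128, 128)).astype(np.float32)
    valid = np.isfinite(obs)
    mean = obs[valid].mean()
    std = obs[valid].std()
    encoded = np.stack([
        np.full_like(obs, mean),   # plane 0: channel mean
        np.full_like(obs, std),    # plane 1: channel standard deviation
        (obs - mean) / std,        # plane 2: normalized observations
    ])
    assert encoded.shape == (3, 128, 128)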
diff --git a/gprof_nn/data/sim.py b/gprof_nn/data/sim.py
index caf15fa..e55071d 100644
--- a/gprof_nn/data/sim.py
+++ b/gprof_nn/data/sim.py
@@ -24,6 +24,7 @@
 import numpy as np
 import pandas as pd
 from pansat import TimeRange, Granule
+from pansat.time import to_datetime
 from pansat.products.satellite.gpm import l1c_r_gpm_gmi
 from pykdtree.kdtree import KDTree
 from rich.progress import Progress
@@ -67,7 +68,6 @@
     calculate_obs_properties,
     RADIUS_OF_INFLUENCE,
     calculate_polarization_weights,
-    decompress_scene
 )
 
 BEAM_WIDTHS = {
@@ -750,7 +750,11 @@ def load_input_data_xtrack(self):
 
     def load_input_data_conical(self):
         upsampling_factors = UPSAMPLING_FACTORS["gmi"]
-        input_data = upsample_data(self.data, upsampling_factors)
+        restrict_vars = [
+            name for name in list(self.data.variables.keys()) + list(self.data.dims)
+            if "pixels_center" not in self.data[name].dims
+        ]
+        input_data = upsample_data(self.data[restrict_vars], upsampling_factors)
         input_data = add_cpcir_data(input_data)
         central_time = input_data.scan_time.data[0] + (input_data.scan_time.data[-1] - input_data.scan_time.data[0]) // 2
@@ -936,12 +940,8 @@ def process_sim_file(
     )
     vars = list(data.variables) + list(data.dims)
-    data = decompress_scene(data, vars)
-
-    if satformer_model is not None:
-        simulate_tbs_satformer(satformer_model, data, sensor)
 
-    if sensor.name != "GMI":
+    if sensor.name.lower() != "gmi" and satformer_model is not None:
         lock = FileLock("cuda.lock")
         with lock:
             simulate_tbs_satformer(
@@ -949,6 +949,9 @@
                 satformer_model,
                 data,
                 sensor,
             )
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+
     if lonlat_bounds is not None:
         lon_ll, lat_ll, lon_ur, lat_ur = lonlat_bounds
@@ -974,8 +977,8 @@ def process_files(
     output_path_1d: Path,
     output_path_3d: Path,
     n_processes: int = 1,
-    start_time: Optional[np.datetime64] = None,
-    end_time: Optional[np.datetime64] = None,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
     split: Optional[str] = None,
     include_cmb_precip: bool = False,
     lonlat_bounds: Optional[Tuple[float, float, float, float]] = None,
@@ -1165,7 +1168,7 @@ def cli(sensor: Sensor,
 
     if start_time is not None:
         try:
-            start_time = np.datetime64(start_time)
+            start_time = to_datetime(np.datetime64(start_time))
         except ValueError:
             LOGGER.error(
                 "Could not parse 'start_time' argument as a numpy.datetime64 object. "
@@ -1176,7 +1179,7 @@
 
     if end_time is not None:
         try:
-            end_time = np.datetime64(end_time)
+            end_time = to_datetime(np.datetime64(end_time))
         except ValueError:
             LOGGER.error(
                 "Could not parse 'end_time' argument as a numpy.datetime64 object. "
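The CLI still parses its time arguments through `numpy.datetime64` but now hands Python `datetime` objects to `process_files`, matching the updated type annotations. For illustration (assuming pansat's `to_datetime` converts a `numpy.datetime64` into a Python `datetime`, which is what its use here implies):

    import numpy as np
    from pansat.time import to_datetime

    # Same conversion chain the updated CLI applies to its time arguments.
    start_time = to_datetime(np.datetime64("2020-01-01T00:00:00"))
    print(type(start_time))  # expected: <class 'datetime.datetime'>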
diff --git a/gprof_nn/plotting.py b/gprof_nn/plotting.py
index cb87f52..83019cb 100644
--- a/gprof_nn/plotting.py
+++ b/gprof_nn/plotting.py
@@ -6,12 +6,14 @@
 Utility functions for plotting.
 """
 import pathlib
+from typing import List
 
 import cartopy.crs as ccrs
 from matplotlib import rc
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 from matplotlib.patches import Rectangle
+from matplotlib.ticker import FixedLocator
 from matplotlib.colors import to_rgba, to_hex, LogNorm
 from matplotlib.cm import ScalarMappable
 import numpy as np
@@ -477,3 +479,29 @@ def add_swath_edges(
     coords = np.stack([x, y, z], axis=-1)
     line_r = pv.Spline(coords)
     scene.add_mesh(line_r, color="k", line_width=2)
+
+
+def add_ticks(
+    ax: plt.Axes,
+    lons: List[float],
+    lats: List[float],
+    left: bool = True,
+    bottom: bool = True
+) -> None:
+    """
+    Add ticks to a cartopy Axes object.
+
+    Args:
+        ax: The Axes object to which to add the ticks.
+        lons: The longitude coordinates at which to add ticks.
+        lats: The latitude coordinates at which to add ticks.
+        left: Whether or not to draw tick labels on the y-axis.
+        bottom: Whether or not to draw tick labels on the x-axis.
+    """
+    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, linewidth=0, color="none")
+    gl.top_labels = False
+    gl.right_labels = False
+    gl.left_labels = left
+    gl.bottom_labels = bottom
+    gl.xlocator = FixedLocator(lons)
+    gl.ylocator = FixedLocator(lats)
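A short usage sketch for the new `add_ticks` helper (the tick positions are arbitrary example values):

    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt
    from gprof_nn.plotting import add_ticks

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    ax.coastlines()
    # Label the given longitudes/latitudes without drawing visible grid lines.
    add_ticks(ax, lons=[-120.0, -100.0, -80.0], lats=[20.0, 30.0, 40.0])
    plt.show()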
diff --git a/gprof_nn/sensors/amsr2.toml b/gprof_nn/sensors/amsr2.toml
index cc47b18..c7c8c33 100644
--- a/gprof_nn/sensors/amsr2.toml
+++ b/gprof_nn/sensors/amsr2.toml
@@ -6,7 +6,7 @@ l1c_file_prefix = "1C.GCOMW1.AMSR2"
 [viewing_geometry]
 kind = "Conical"
 altitude = 726.9e3
-earth_incidence_angle = 55.0
+earth_incidence_angle = 55.32936
 scan_range = 140.0
 pixels_per_scan = 492
 scan_offset = 5e3