Fixes for CloudSat collocation extraction.
simonpf committed Nov 29, 2024
1 parent 4107677 commit e74287f
Showing 6 changed files with 167 additions and 55 deletions.
41 changes: 32 additions & 9 deletions gprof_nn/data/cloudsat.py
@@ -41,6 +41,7 @@
calculate_obs_properties,
extract_scenes,
mask_invalid_values,
RADIUS_OF_INFLUENCE,
PANSAT_PRODUCTS
)

@@ -66,6 +67,7 @@ def extract_cloudsat_scenes(
match: Tuple[Granule, Tuple[Granule]],
output_path: Path,
scene_size: Tuple[int, int],
high_res: bool = False
) -> None:
"""
Extract training scenes between a GPM sensor and CloudSat observations.
@@ -76,6 +78,7 @@ def extract_cloudsat_scenes(
retrievals.
output_path: The path to which to write the extracted training scenes.
scene_size: The size of the training scenes to extract.
high_res: Whether to upsample data to ~ 5 km resolution.
"""
input_granule, target_granules = match
target_granules = merge_granules(sorted(list(target_granules)))
@@ -87,9 +90,10 @@
invalid = input_data[var].data < -1_000
input_data[var].data[invalid] = np.nan

upsampling_factors = UPSAMPLING_FACTORS[sensor.name.lower()]
if max(upsampling_factors) > 1:
input_data = upsample_data(input_data, upsampling_factors)
if high_res:
upsampling_factors = UPSAMPLING_FACTORS[sensor.name.lower()]
if max(upsampling_factors) > 1:
input_data = upsample_data(input_data, upsampling_factors)
input_data = add_cpcir_data(input_data)

lons = input_data.longitude.data
@@ -138,7 +142,13 @@ def extract_cloudsat_scenes(
cs_data["levels"] = (("levels",), levels)
cs_data = cs_data.drop_dims("bins")

cs_data_r = resample_data(cs_data.transpose("rays", "levels"), swath, new_dims=(("scans", "pixels")))
cs_data_r = resample_data(
cs_data.transpose("rays", "levels"),
swath,
new_dims=(("scans", "pixels")),
unique=True,
radius_of_influence=rof_in / 3 if high_res else rof_in
)

input_data["surface_precip"] = (
("scans", "pixels"),
@@ -195,10 +205,10 @@ def extract_cloudsat_scenes(

scenes = extract_scenes(
input_data,
n_scans=128,
n_pixels=128,
n_scans=scene_size[0],
n_pixels=scene_size[1],
overlapping=True,
min_valid=100,
min_valid=scene_size[0] // 2,
reference_var="valid",
offset=50
)
@@ -256,6 +266,7 @@ def extract_samples(
end_time: np.datetime64,
output_path: Path,
scene_size: Tuple[int, int] = (64, 64),
high_res: bool = False
) -> None:
"""
Extract GPM-CloudSat training scenes.
@@ -266,6 +277,7 @@ def extract_samples(
end_time: The end of the time period for which to extract training data.
output_path: The path to which to write the extracted training scenes.
scene_size: The size of the training scenes to extract.
high_res: Whether to extract samples at high resolution.
"""
input_products = PANSAT_PRODUCTS[sensor.name.lower()]
target_product = l2c_rain_profile
@@ -282,6 +294,7 @@ def extract_samples(
match,
output_path,
scene_size=scene_size,
high_res=high_res
)
except Exception:
LOGGER.exception(
@@ -295,7 +308,8 @@ def extract_samples(
@click.argument("days", nargs=-1, type=int)
@click.argument("output_path")
@click.option("--n_processes", default=None, type=int)
@click.option("--scene_size", type=tuple, default=(64, 64))
@click.option("--scene_size", type=str, default=(64, 64))
@click.option("--high_res", type=bool, default=False)
def cli(
sensor: Sensor,
year: int,
@@ -304,9 +318,10 @@ def cli(
output_path: Path,
n_processes: int,
scene_size: Tuple[int, int] = (64, 64),
high_res: bool = False
) -> None:
"""
Extract CloudSat scenes data for SATFORMER training.
Extract CloudSat scene data for GPROF-NN and GPROF-NN HR training.
Args:
sensor: The name of the GPM sensor.
@@ -315,6 +330,7 @@ def cli(
days: A list of the days of the month for which to extract the training data.
output_path: The path to which to write the training data.
n_processes: The number of processes to use for parallel processing
high_res: Whether to extract samples at high resolution.
"""
from gprof_nn import sensors

@@ -334,6 +350,11 @@ def cli(
LOGGER.error("The 'output' argument must point to a directory.")
return 1

scene_size = tuple(list(map(int, scene_size.split(","))))
if len(scene_size) == 1:
scene_size = (scene_size[0],) * 2
print(scene_size)

if n_processes is None:
for day in track(days):
start_time = datetime(year, month, day)
@@ -344,6 +365,7 @@ def cli(
end_time,
output_path=output_path,
scene_size=scene_size,
high_res=high_res
)
else:
pool = ProcessPoolExecutor(max_workers=n_processes)
@@ -359,6 +381,7 @@ def cli(
end_time,
output_path=output_path,
scene_size=scene_size,
high_res=high_res
)
)

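The `--scene_size` option above now arrives as a comma-separated string and is parsed into a tuple inside `cli`. A minimal sketch of that parsing, assuming the behavior shown in the hunk above (the helper name `parse_scene_size` is hypothetical):

```python
from typing import Tuple


def parse_scene_size(scene_size: str) -> Tuple[int, int]:
    # "64,64" -> (64, 64); a single value such as "128" is used for both dimensions.
    parsed = tuple(map(int, scene_size.split(",")))
    if len(parsed) == 1:
        parsed = (parsed[0],) * 2
    return parsed


assert parse_scene_size("64,64") == (64, 64)
assert parse_scene_size("128") == (128, 128)
```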
4 changes: 2 additions & 2 deletions gprof_nn/data/preprocessor.py
@@ -717,7 +717,7 @@ def has_preprocessor():

# Dictionary mapping sensor IDs to preprocessor executables.
PREPROCESSOR_EXECUTABLES = {
"GMI": "gprof2024pp_GMI_L1C",
"GMI": "gprof2023pp_GMI_L1C",
"MHS": "gprof2023pp_MHS_L1C",
"TMIPR": "gprof2021pp_TMI_L1C",
"TMIPO": "gprof2021pp_TMI_L1C",
@@ -734,7 +734,7 @@ def has_preprocessor():
("GMI", "SSMIS"): "gprof2021pp_GMI_SSMIS_L1C",
("GMI", "AMSR2"): "gprof2023pp_GMI_L1C",
("GMI", "AMSRE"): "gprof2021pp_GMI_AMSRE_L1C",
("GMI", "ATMS"): "gprof2021pp_GMI_ATMS_L1C",
("GMI", "ATMS"): "gprof2023pp_GMI_L1C",
("GMI", "TMS"): "gprof2023pp_GMI_L1C",
}

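The tables above map sensor IDs, and (reference, target) sensor pairs, to preprocessor executables. A minimal lookup sketch assuming dictionaries of the shape shown above; the helper name and the pair-table name are illustrative, not the module's actual API:

```python
from typing import Optional, Tuple

# Illustrative subset of the tables above; the real module defines the full mappings.
PREPROCESSOR_EXECUTABLES = {
    "GMI": "gprof2023pp_GMI_L1C",
    "MHS": "gprof2023pp_MHS_L1C",
}
PAIR_PREPROCESSOR_EXECUTABLES = {  # hypothetical name for the pair-keyed table
    ("GMI", "ATMS"): "gprof2023pp_GMI_L1C",
    ("GMI", "TMS"): "gprof2023pp_GMI_L1C",
}


def get_preprocessor(sensor: str, target: Optional[str] = None) -> str:
    """Return the preprocessor executable for a sensor or a (sensor, target) pair."""
    if target is not None:
        return PAIR_PREPROCESSOR_EXECUTABLES[(sensor, target)]
    return PREPROCESSOR_EXECUTABLES[sensor]


assert get_preprocessor("GMI") == "gprof2023pp_GMI_L1C"
assert get_preprocessor("GMI", "ATMS") == "gprof2023pp_GMI_L1C"
```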
121 changes: 89 additions & 32 deletions gprof_nn/data/pretraining.py
@@ -44,7 +44,8 @@
upsample_data,
add_cpcir_data,
calculate_obs_properties,
mask_invalid_values
mask_invalid_values,
RADIUS_OF_INFLUENCE
)
from gprof_nn.data.l1c import L1CFile
from gprof_nn.logging import (
@@ -70,11 +71,6 @@
"atms": (3, 3,),
"amsr2": (1, 1)
}
RADIUS_OF_INFLUENCE = {
"gmi": 20e3,
"atms": 100e3,
"amsr2": 10e3
}


def extract_pretraining_scenes(
@@ -206,40 +202,101 @@ def __getitem__(self, index: int) -> int:

def load_data(self, ind: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]:

input_granule, target_granules = self.inputs[ind]
inpt_granule, target_granules = self.inputs[ind]
target_granule = sorted(list(target_granules))[0]

input_data = run_preprocessor(input_granule)
input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=self.radius_of_influence)
target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=self.radius_of_influence)
inpt_file = L1CFile(inpt_granule.file_record.local_path)
inpt_sensor = inpt_file.sensor
targ_file = L1CFile(target_granule.file_record.local_path)
targ_sensor = targ_file.sensor

inpt_data = run_preprocessor(inpt_granule)
upsampling_factors = UPSAMPLING_FACTORS[inpt_sensor.name.lower()]
inpt_data = upsample_data(inpt_data, upsampling_factors)
inpt_data = add_cpcir_data(inpt_data)
roi_inpt = RADIUS_OF_INFLUENCE[inpt_sensor.name.lower()]
inpt_observations = calculate_obs_properties(
inpt_data,
inpt_granule,
radius_of_influence=roi_inpt
)
roi_targ = RADIUS_OF_INFLUENCE[targ_sensor.name.lower()]
target_observations = calculate_obs_properties(
inpt_data,
target_granule,
radius_of_influence=roi_targ
)

lons = inpt_data.longitude.data
valid = np.isfinite(inpt_data.longitude.data)

inpt_obs = inpt_observations.observations.data
inpt_meta = inpt_observations.meta_data.data

obs_in = []
meta_in = []
for ind, obs in enumerate(inpt_obs):
obs[..., ~valid] = np.nan
# Use a per-channel mask so the swath-validity mask above is not overwritten between channels.
valid_obs = np.isfinite(obs)
mean = np.mean(obs[valid_obs])
std = np.std(obs[valid_obs])
obs_n = (obs - mean) / std
obs = np.stack([
np.ones_like(obs_n) * mean,
np.ones_like(obs_n) * std,
obs_n
])
obs_in.append(torch.tensor(obs))
meta = inpt_meta[ind]
meta[..., ~valid_obs] = np.nan
meta_in.append(torch.tensor(inpt_meta[ind]))


obs_in = torch.stack(obs_in, 1)[None]
meta_in = torch.stack(meta_in, 1)[None]
obs_in_mask = torch.isnan(obs_in).all(1).all(-1).all(-1)

inpt = {
"observations": obs_in,
"input_observation_props": meta_in,
"input_observation_mask": obs_in_mask,
}

anc_vars = [
"two_meter_temperature",
"total_column_water_vapor",
"leaf_area_index",
"land_fraction",
"ice_fraction",
"elevation",
"ir_observations",
]
for anc_var in anc_vars:
anc_data = torch.tensor(inpt_data[anc_var].data).to(dtype=torch.float32)
if anc_data.dim() < 3:
anc_data = anc_data[None]
anc_data = anc_data[None, :, None]
anc_mask = torch.isnan(anc_data).all()[None, None]
inpt[anc_var] = anc_data
inpt[anc_var + "_mask"] = anc_mask

props = torch.tensor(target_observations["meta_data"].data)[None]
inpt["output_observation_props"] = props.transpose(1, 2)

training_data = xr.Dataset({
"latitude": input_data.latitude,
"longitude": input_data.longitude,
"input_observations": input_obs.observations.rename(channels="input_channels"),
"input_meta_data": input_obs.meta_data.rename(channels="input_channels"),
"target_observations": target_obs.observations.rename(channels="target_channels"),
"target_meta_data": target_obs.meta_data.rename(channels="target_channels"),
"latitude": inpt_data.latitude,
"longitude": inpt_data.longitude,
"input_observations": inpt_observations.observations.rename(channels="input_channels"),
"input_meta_data": inpt_observations.meta_data.rename(channels="input_channels"),
"target_observations": target_observations.observations.rename(channels="target_channels"),
"target_meta_data": target_observations.meta_data.rename(channels="target_channels"),
})
tbs = training_data.input_observations.data
tbs[tbs < 0] = np.nan
n_seq_in = tbs.shape[0]
mask = np.all(np.isnan(tbs), axis=(1, 2))
tbs = training_data.target_observations.data
tbs[tbs < 0] = np.nan
n_seq_out = tbs.shape[0]

input_data = {
"observations": torch.tensor(training_data.input_observations.data)[None, None],
"input_observation_mask": torch.tensor(mask, dtype=torch.bool)[None],
"input_observation_props": torch.tensor(training_data.input_meta_data.data)[None].transpose(1, 2),
"dropped_observation_props": torch.tensor(training_data.input_meta_data.data)[11:][None].transpose(1, 2),
"output_observation_props": torch.tensor(training_data.target_meta_data.data)[None].transpose(1, 2),
}

filename = "match_" + target_granule.time_range.start.strftime("%Y%m%d%H%M%s") + ".nc"

return input_data, filename, training_data
print({key: tensor.shape for key,tensor in inpt.items()})

return inpt, filename, training_data


def extract_samples(
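The rewritten `load_data` above normalizes each observation channel separately and stacks the channel mean and standard deviation alongside the normalized field, so the downstream model sees both the anomaly and its scale. A minimal sketch of that per-channel step, assuming a `(scans, pixels)` array per channel (the helper name is illustrative):

```python
import numpy as np


def normalize_channel(obs: np.ndarray) -> np.ndarray:
    """Return a (3, scans, pixels) stack of [channel mean, channel std, normalized obs]."""
    valid = np.isfinite(obs)
    mean = obs[valid].mean()
    std = obs[valid].std()
    return np.stack([
        np.full_like(obs, mean),
        np.full_like(obs, std),
        (obs - mean) / std,
    ])


obs = np.random.normal(250.0, 10.0, size=(128, 128)).astype(np.float32)
stacked = normalize_channel(obs)
assert stacked.shape == (3, 128, 128)
```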
25 changes: 14 additions & 11 deletions gprof_nn/data/sim.py
@@ -24,6 +24,7 @@
import numpy as np
import pandas as pd
from pansat import TimeRange, Granule
from pansat.time import to_datetime
from pansat.products.satellite.gpm import l1c_r_gpm_gmi
from pykdtree.kdtree import KDTree
from rich.progress import Progress
@@ -67,7 +68,6 @@
calculate_obs_properties,
RADIUS_OF_INFLUENCE,
calculate_polarization_weights,
decompress_scene
)

BEAM_WIDTHS = {
Expand Down Expand Up @@ -750,7 +750,11 @@ def load_input_data_xtrack(self):
def load_input_data_conical(self):

upsampling_factors = UPSAMPLING_FACTORS["gmi"]
input_data = upsample_data(self.data, upsampling_factors)
restrict_vars = [
name for name in list(self.data.variables.keys()) + list(self.data.dims)
if "pixels_center" not in self.data[name].dims
]
input_data = upsample_data(self.data[restrict_vars], upsampling_factors)
input_data = add_cpcir_data(input_data)

central_time = input_data.scan_time.data[0] + (input_data.scan_time.data[-1] - input_data.scan_time.data[0]) // 2
@@ -936,19 +940,18 @@ def process_sim_file(
)

vars = list(data.variables) + list(data.dims)
data = decompress_scene(data, vars)

if satformer_model is not None:
simulate_tbs_satformer(satformer_model, data, sensor)

if sensor.name != "GMI":
if sensor.name.lower() != "gmi" and satformer_model is not None:
lock = FileLock("cuda.lock")
with lock:
simulate_tbs_satformer(
satformer_model,
data,
sensor,
)
torch.cuda.synchronize()
torch.cuda.empty_cache()


if lonlat_bounds is not None:
lon_ll, lat_ll, lon_ur, lat_ur = lonlat_bounds
@@ -974,8 +977,8 @@ def process_files(
output_path_1d: Path,
output_path_3d: Path,
n_processes: int = 1,
start_time: Optional[np.datetime64] = None,
end_time: Optional[np.datetime64] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
split: Optional[str] = None,
include_cmb_precip: bool = False,
lonlat_bounds: Optional[Tuple[float, float, float, float]] = None,
@@ -1165,7 +1168,7 @@ def cli(sensor: Sensor,

if start_time is not None:
try:
start_time = np.datetime64(start_time)
start_time = to_datetime(np.datetime64(start_time))
except ValueError:
LOGGER.error(
"Coud not parse 'start_time' argument as numpy.datetime64 object. "
@@ -1176,7 +1179,7 @@

if end_time is not None:
try:
end_time = np.datetime64(end_time)
end_time = to_datetime(np.datetime64(end_time))
except ValueError:
LOGGER.error(
"Coud not parse 'end_time' argument as numpy.datetime64 object. "
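The CLI above now converts the user-supplied time strings from `numpy.datetime64` to `datetime` objects via `pansat.time.to_datetime` before passing them to `process_files`. A rough plain-numpy equivalent of that conversion, shown only to illustrate the behavior (assuming second precision is sufficient):

```python
from datetime import datetime

import numpy as np

time_64 = np.datetime64("2021-01-01T00:00:00")
# Casting to second precision and then to a Python object yields a datetime.datetime.
time_dt = time_64.astype("datetime64[s]").astype(datetime)
assert isinstance(time_dt, datetime)
```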