Upsample low-res input, add ancillary data, and IR obs.
simonpf committed Oct 4, 2024
1 parent 67ec68e commit 9d7e9fe
211 changes: 167 additions & 44 deletions gprof_nn/data/pretraining.py
@@ -3,8 +3,9 @@
gprof_nn.data.pretraining
=========================
The module provides functionality to extract data for unsupervised pre-training
of the GPROF-NN T model.
This module provides functionality to extract collocated observations from various
sensors of the GPM constellation and to assemble training samples for an observation
translator model.
"""
from calendar import monthrange
from concurrent.futures import ProcessPoolExecutor, as_completed
@@ -28,7 +29,8 @@
from pansat.products.satellite.gpm import (
l1c_gpm_gmi,
l1c_npp_atms,
l1c_gcomw1_amsr2
l1c_gcomw1_amsr2,
merged_ir
)
from pansat.utils import resample_data
from pyresample.geometry import SwathDefinition
@@ -51,26 +53,38 @@
LOGGER = logging.getLogger(__name__)


# pansat products for each sensor.
PRODUCTS = {
"gmi": (l1c_gpm_gmi,),
"atms": (l1c_npp_atms,),
"amsr2": (l1c_gcomw1_amsr2,)
}

UPSAMPLING_FACTORS = {
"gmi": (3, 1),
"atms": (3, 3,),
"amsr2": (1, 1)
}

POLARIZATIONS = {
"H": 0,
"QH": 1,
"V": 2,
"QV": 2,
}


BEAM_WIDTHS = {
"gmi": [1.75, 1.75, 1.0, 1.0, 0.9, 0.9, 0.9, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
"atms": [5.2, 5.2, 2.2, 1.1, 1.1, 1.1, 1.1, 1.1],
"amsr2": [1.2, 1.2, 0.65, 0.65, 0.75, 0.75, 0.35, 0.35, 0.15, 0.15, 0.15, 0.15],
}

RADIUS_OF_INFLUENCE = {
"gmi": 20e3,
"atms": 100e3,
"amsr2": 10e3
}


CHANNEL_REGEXP = re.compile("([\d\.\s\+\/-]*)\s*GHz\s*(\w*)-Pol")
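For illustration, CHANNEL_REGEXP splits L1C channel descriptors into a frequency part and a polarization part, which POLARIZATIONS then maps to a numeric code. A minimal sketch, assuming descriptors of the form "18.7 GHz V-Pol" (the exact strings in the L1C metadata may differ):

import re

CHANNEL_REGEXP = re.compile(r"([\d\.\s\+\/-]*)\s*GHz\s*(\w*)-Pol")
POLARIZATIONS = {"H": 0, "QH": 1, "V": 2, "QV": 2}

# Hypothetical descriptor string; real L1C channel names may differ slightly.
match = CHANNEL_REGEXP.match("18.7 GHz V-Pol")
if match is not None:
    frequency, polarization = match.groups()  # ("18.7 ", "V")
    print(float(frequency), POLARIZATIONS[polarization])  # 18.7 2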

@@ -155,15 +169,8 @@ def run_preprocessor(gpm_granule: Granule) -> xr.Dataset:
finally:
os.chdir(old_dir)

preprocessor_data = preprocessor_data.rename({
"scans": "scan",
"pixels": "pixel",
"channels": "channel_gprof",
"brightness_temperatures": "observations_gprof",
"earth_incidence_angle": "earth_incidence_angle_gprof"
})
invalid = preprocessor_data.observations_gprof.data < 0
preprocessor_data.observations_gprof.data[invalid] = np.nan
invalid = preprocessor_data.brightness_temperatures.data < 0
preprocessor_data.brightness_temperatures.data[invalid] = np.nan

return preprocessor_data
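The masking above follows the convention that negative brightness temperatures mark missing or invalid pixels; replacing them with NaN keeps them out of later statistics. A one-line illustration with fabricated values:

import numpy as np

tbs = np.array([250.0, -9999.9, 300.0])  # fabricated Tbs with a missing-value flag
tbs[tbs < 0] = np.nan
print(tbs)  # [250.  nan 300.]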

@@ -176,7 +183,7 @@ def calculate_angles(
sensor_alts: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calculate zenith and azimuth angles describing the observations geometry.
Calculate zenith and azimuth angles describing the observation geometry.
Args:
fp_lons: Array containing the longitude coordinates of the observation
@@ -197,20 +204,22 @@
fp_lla = np.stack((fp_lons, fp_lats, np.zeros_like(fp_lons)), -1)
fp_ecef = lla_to_ecef(fp_lla)
local_up = fp_ecef / np.linalg.norm(fp_ecef, axis=-1, keepdims=True)
fp_east = fp_lla.copy()
fp_east[..., 0] = 0.1
fp_east = lla_to_ecef(fp_east)
fp_east = fp_east / np.linalg.norm(fp_east, axis=-1, keepdims=True)
local_north = np.cross(local_up, fp_east)

sensor_ecef = np.broadcast_to(sensor_ecef[..., None, :], fp_lla.shape)
fp_west = fp_lla.copy()
fp_west[..., 0] -= 0.1
fp_west = lla_to_ecef(fp_west) - fp_ecef
fp_west /= np.linalg.norm(fp_west, axis=-1, keepdims=True)
fp_north = fp_lla.copy()
fp_north[..., 1] += 0.1
fp_north = lla_to_ecef(fp_north) - fp_ecef
fp_north /= np.linalg.norm(fp_north, axis=-1, keepdims=True)

if sensor_ecef.ndim < fp_lla.ndim:
sensor_ecef = np.broadcast_to(sensor_ecef[..., None, :], fp_lla.shape)
los = sensor_ecef - fp_ecef
zenith = np.arccos((local_up * los).sum(-1) / np.linalg.norm(los, axis=-1))
proj = los - local_up * (local_up * los).sum(-1, keepdims=True)

azimuth = np.arccos((local_north * proj).sum(-1) / np.linalg.norm(proj, axis=-1))
mask = np.isclose(np.linalg.norm(proj, axis=-1), 0.0)
azimuth[mask] = 0.0
azimuth = np.arctan2((los * fp_west).sum(-1), (los * fp_north).sum(-1))
azimuth = np.nan_to_num(azimuth, nan=0.0)

return np.rad2deg(zenith), np.rad2deg(azimuth)
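As a sanity check on the geometry above: for a sensor directly over the footprint, the line of sight is parallel to the local up vector, so the zenith angle is zero. A standalone sketch using a spherical Earth (the module's lla_to_ecef presumably uses the WGS84 ellipsoid, so real numbers differ slightly):

import numpy as np

def lla_to_ecef_sphere(lla, radius=6371e3):
    # Spherical-Earth stand-in for lla_to_ecef; for illustration only.
    lon, lat = np.deg2rad(lla[..., 0]), np.deg2rad(lla[..., 1])
    r = radius + lla[..., 2]
    return np.stack(
        (r * np.cos(lat) * np.cos(lon), r * np.cos(lat) * np.sin(lon), r * np.sin(lat)),
        axis=-1,
    )

fp = lla_to_ecef_sphere(np.array([0.0, 0.0, 0.0]))        # footprint on the equator
sensor = lla_to_ecef_sphere(np.array([0.0, 0.0, 407e3]))  # sensor straight above it
los = sensor - fp
up = fp / np.linalg.norm(fp)
print(np.degrees(np.arccos(up @ los / np.linalg.norm(los))))  # ~0.0 (nadir view)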

Expand All @@ -226,9 +235,8 @@ def calculate_obs_properties(
Args:
preprocessor_data: The preprocessor data to which to resample all
observations.
granule:
granule: A pansat granule defining the section of an orbit containing the overpass.
radius_of_influence: The radius of influence to use for resampling the L1C observations.
"""
lons = preprocessor_data.longitude.data
lats = preprocessor_data.latitude.data
@@ -319,6 +327,100 @@ def calculate_obs_properties(
})


def upsample_data(
data: xr.Dataset,
upsampling_factors: Tuple[int, int]
) -> xr.Dataset:
"""
Upsample preprocessor data along scans and pixels.
Args:
data: An xarray.Dataset containing preprocessor data.
upsampling_factors: A tuple describing the upsampling factors along scans and pixels.
Return:
The preprocessor data upsampled by the given factors along scans and pixels.
"""
float_vars = [
"latitude", "longitude", "brightness_temperatures", "total_column_water_vapor", "two_meter_temperature",
"moisture_convergence", "leaf_area_index", "snow_depth", "land_fraction", "ice_fraction", "elevation",
]
scan_time = data["scan_time"]
data = data[float_vars]

n_scans = data.scans.size
n_scans_up = upsampling_factors[0] * n_scans
new_scans = np.linspace(data.scans[0], data.scans[-1], n_scans_up)

n_pixels = data.pixels.size
n_pixels_up = upsampling_factors[1] * n_pixels
new_pixels = np.linspace(data.pixels[0], data.pixels[-1], n_pixels_up)

data = data.interp(scans=new_scans, pixels=new_pixels).drop_vars(["pixels", "scans"])
scan_time_int = scan_time.astype(np.int64)
scan_time_new = scan_time_int.interp(scans=new_scans, method="nearest")
data["scan_time"] = (("scans",), scan_time_new.data.astype("datetime64[ns]"))

return data
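The heavy lifting in upsample_data is xarray's interp, which linearly interpolates onto a denser scan/pixel grid. A toy illustration with fabricated data (xarray delegates the interpolation to scipy):

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(8.0).reshape(4, 2),
    dims=("scans", "pixels"),
    coords={"scans": np.arange(4), "pixels": np.arange(2)},
)
# Upsample 3x along scans, 1x along pixels, as for GMI (UPSAMPLING_FACTORS["gmi"]).
new_scans = np.linspace(da.scans[0], da.scans[-1], 3 * da.scans.size)
print(da.interp(scans=new_scans).shape)  # (12, 2)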


def add_cpcir_data(
preprocessor_data: xr.Dataset,
) -> xr.Dataset:
"""
Add CPCIR 11 um IR observations to the preprocessor data.
Args:
preprocessor_data: An xarray.Dataset containing the data from the preprocessor.
Return:
The preprocessor data with an additional variable 'ir_observations' containing CPCIR 11 um
Tbs if available.
"""
scan_time_start = preprocessor_data.scan_time.data[0]
scan_time_end = preprocessor_data.scan_time.data[-1]
time_c = scan_time_start + 0.5 * (scan_time_end - scan_time_start)
time_range = TimeRange(time_c)
recs = merged_ir.get(time_range)

if len(recs) == 0:
preprocessor_data["ir_observations"] = (("scans", "pixels"), np.nan * np.zeros_like(preprocessor_data.longitude.data))
return preprocessor_data

with xr.open_dataset(recs[0].local_path) as cpcir_data:
lons = cpcir_data.lon.data
lats = cpcir_data.lat.data

lat_min = preprocessor_data.latitude.data.min()
lat_max = preprocessor_data.latitude.data.max()
inds = np.where((lat_min <= lats) * (lats <= lat_max))[0]
if len(inds) < 2:
lat_start, lat_end = 0, None
else:
lat_start, lat_end = inds[[0, -1]]

lon_min = preprocessor_data.longitude.data.min()
lon_max = preprocessor_data.longitude.data.max()
inds = np.where((lon_min <= lons) * (lons <= lon_max))[0]
if len(inds) < 2:
lon_start, lon_end = 0, None
else:
lon_start, lon_end = inds[[0, -1]]

cpcir_tbs = cpcir_data.Tb[{"lat": slice(lat_start, lat_end), "lon": slice(lon_start, lon_end)}]

scan_time = preprocessor_data.scan_time
scan_time, _ = xr.broadcast(scan_time, preprocessor_data.longitude)

cpcir_tbs = cpcir_tbs.interp(
lat = preprocessor_data.latitude,
lon = preprocessor_data.longitude,
).rename(time="ir_obs")
preprocessor_data["ir_observations"] = cpcir_tbs

return preprocessor_data
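The final interp call relies on xarray's pointwise (vectorized) indexing: because the latitude/longitude indexers are 2-D DataArrays with swath dimensions, the result is one IR Tb per swath pixel rather than an outer product over lat and lon. A toy sketch with a fabricated grid and swath:

import numpy as np
import xarray as xr

tb = xr.DataArray(
    np.arange(16.0).reshape(4, 4),
    dims=("lat", "lon"),
    coords={"lat": np.linspace(-1.5, 1.5, 4), "lon": np.linspace(-1.5, 1.5, 4)},
)
lats = xr.DataArray(np.zeros((2, 3)), dims=("scans", "pixels"))  # fabricated swath coords
lons = xr.DataArray(np.ones((2, 3)), dims=("scans", "pixels"))
print(tb.interp(lat=lats, lon=lons).shape)  # (2, 3): one value per swath pixel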




def extract_pretraining_scenes(
@@ -327,20 +429,44 @@
match: Tuple[Granule, Tuple[Granule]],
output_path: Path,
scene_size: Tuple[int, int],
radius_of_influence: float
) -> None:
"""
Extract training scenes from a match-up of two GPM sensors.
"""
input_granule, target_granules = match
target_granules = merge_granules(sorted(list(target_granules)))
for target_granule in target_granules:

input_data = run_preprocessor(input_granule)
input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=radius_of_influence)
target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=radius_of_influence)
for var in input_data:
if np.issubdtype(input_data[var].data.dtype, np.floating):
invalid = input_data[var].data < -1_000
input_data[var].data[invalid] = np.nan

upsampling_factors = UPSAMPLING_FACTORS[input_sensor.name.lower()]
if max(upsampling_factors) > 1:
input_data = upsample_data(input_data, upsampling_factors)
input_data = add_cpcir_data(input_data)

rof_in = RADIUS_OF_INFLUENCE[input_sensor.name.lower()]
rof_targ = RADIUS_OF_INFLUENCE[target_sensor.name.lower()]
input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=rof_in)
target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=rof_targ)


training_data = xr.Dataset({
"input_observations": input_obs.observations.rename(channels="input_channels"),
"input_meta_data": input_obs.meta_data.rename(channels="input_channels"),
"target_observations": target_obs.observations.rename(channels="target_channels"),
"target_meta_data": target_obs.meta_data.rename(channels="target_channels"),
"two_meter_temperature": input_data.two_meter_temperature,
"total_column_water_vapor": input_data.two_meter_temperature,
"leaf_area_index": input_data.leaf_area_index,
"land_fraction": input_data.land_fraction,
"ice_fraction": input_data.ice_fraction,
"elevation": input_data.elevation,
"ir_observations": input_data.ir_observations,
})

tbs = training_data.input_observations.data
Expand All @@ -354,9 +480,9 @@ def extract_pretraining_scenes(

scenes = extract_scenes(
training_data,
n_scans=64,
n_pixels=64,
overlapping=True,
n_scans=128,
n_pixels=128,
overlapping=False,
min_valid=50,
reference_var="valid",
)
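extract_scenes itself is defined elsewhere in gprof_nn; conceptually, with overlapping=False it tiles the swath into fixed 128 x 128 scenes and keeps those whose reference variable ("valid" here) has at least min_valid valid pixels. A rough standalone sketch of that idea, not the actual implementation:

import numpy as np

def tile_scenes(valid, n_scans=128, n_pixels=128, min_valid=50):
    # Yield scan/pixel offsets of non-overlapping scenes with enough valid pixels.
    for i in range(0, valid.shape[0] - n_scans + 1, n_scans):
        for j in range(0, valid.shape[1] - n_pixels + 1, n_pixels):
            if valid[i:i + n_scans, j:j + n_pixels].sum() >= min_valid:
                yield i, j

valid = np.random.rand(300, 256) > 0.5  # fabricated validity mask
print(list(tile_scenes(valid)))         # e.g. [(0, 0), (0, 128), (128, 0), (128, 128)]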
Expand Down Expand Up @@ -416,14 +542,17 @@ def load_data(self, ind: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]
tbs = training_data.input_observations.data
tbs[tbs < 0] = np.nan
n_seq_in = tbs.shape[0]
mask = np.all(np.isnan(tbs), axis=(1, 2))
tbs = training_data.target_observations.data
tbs[tbs < 0] = np.nan
n_seq_out = tbs.shape[0]

input_data = {
"input_observations": torch.tensor(training_data.input_observations.data)[None, None],
"input_meta": torch.tensor(training_data.input_meta_data.data)[None].transpose(1, 2),
"output_meta": torch.tensor(training_data.target_meta_data.data)[None].transpose(1, 2),
"observations": torch.tensor(training_data.input_observations.data)[None, None],
"input_observation_mask": torch.tensor(mask, dtype=torch.bool)[None],
"input_observation_props": torch.tensor(training_data.input_meta_data.data)[None].transpose(1, 2),
"dropped_observation_props": torch.tensor(training_data.input_meta_data.data)[11:][None].transpose(1, 2),
"output_observation_props": torch.tensor(training_data.target_meta_data.data)[None].transpose(1, 2),
}
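For orientation, the indexing above only adds batch/channel axes and reorders the meta dimensions; a toy shape check (the channel and feature counts are made up):

import torch

obs = torch.zeros(13, 128, 128)   # (channels, scans, pixels); 13 channels as for GMI
meta = torch.zeros(13, 6)         # (channels, features); the feature count is assumed
print(obs[None, None].shape)             # torch.Size([1, 1, 13, 128, 128])
print(meta[None].transpose(1, 2).shape)  # torch.Size([1, 6, 13])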

filename = "match_" + target_granule.time_range.start.strftime("%Y%m%d%H%M%s") + ".nc"
@@ -438,7 +567,6 @@ def extract_samples(
end_time: np.datetime64,
output_path: Path,
scene_size: Tuple[int, int] = (64, 64),
radius_of_influence: float = 20e3
) -> None:
"""
Extract pretraining samples.
@@ -460,7 +588,6 @@
match,
output_path,
scene_size=scene_size,
radius_of_influence=radius_of_influence
)


Expand Down Expand Up @@ -587,7 +714,6 @@ def process_l1c_files(
@click.argument("output_path")
@click.option("--n_processes", default=None, type=int)
@click.option("--scene_size", type=tuple, default=(64, 64))
@click.option("--radius_of_influence", default=100e3)
def cli(
input_sensor: Sensor,
target_sensor: Sensor,
@@ -597,7 +723,6 @@
output_path: Path,
n_processes: int,
scene_size: Tuple[int, int] = (64, 64),
radius_of_influence: float = 100e3
) -> None:
"""
Extract pretraining data for SATFORMER training.
Expand Down Expand Up @@ -645,14 +770,13 @@ def cli(
end_time,
output_path=output_path,
scene_size=scene_size,
radius_of_influence=radius_of_influence
)
else:
pool = ProcessPoolExecutor(max_workers=n_processes)
tasks = []
for day in days:
start_time = datetime(year, month, day)
end_time = datetime(year, month, day + 1)
end_time = datetime(year, month, day)
tasks.append(
pool.submit(
extract_samples,
Expand All @@ -662,7 +786,6 @@ def cli(
end_time,
output_path=output_path,
scene_size=scene_size,
radius_of_influence=radius_of_influence
)
)

