diff --git a/gprof_nn/data/pretraining.py b/gprof_nn/data/pretraining.py
index 7f18b8f..2683542 100644
--- a/gprof_nn/data/pretraining.py
+++ b/gprof_nn/data/pretraining.py
@@ -65,7 +65,14 @@
 }
 
 
-CHANNEL_REGEXP = re.compile("([\d\.\+-]*)\s*GHz\s*(\w*)-Pol")
+BEAM_WIDTHS = {
+    "gmi": [1.75, 1.75, 1.0, 1.0, 0.9, 0.9, 0.9, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
+    "atms": [5.2, 5.2, 2.2, 1.1, 1.1, 1.1, 1.1, 1.1],
+    "amsr2": [1.2, 1.2, 0.65, 0.65, 0.75, 0.75, 0.35, 0.35, 0.15, 0.15, 0.15, 0.15],
+}
+
+
+CHANNEL_REGEXP = re.compile(r"([\d\.\s\+\/-]*)\s*GHz\s*(\w*)-Pol")
 
 
 SEM_A = 6_378_137.0
@@ -212,7 +219,6 @@ def calculate_obs_properties(
     preprocessor_data: xr.Dataset,
     granule: Granule,
     radius_of_influence: float = 5e3,
-    beam_width: float = 1.0
 ) -> xr.Dataset:
     """
     Extract observations and corresponding meta data from granule.
@@ -224,7 +230,6 @@ def calculate_obs_properties(
 
 
     """
-
     lons = preprocessor_data.longitude.data
     lats = preprocessor_data.latitude.data
     swath = SwathDefinition(lats=lats, lons=lons)
@@ -232,6 +237,9 @@ def calculate_obs_properties(
     observations = []
     meta_data = []
 
+    l1c_file = L1CFile(granule.file_record.local_path)
+    sensor = l1c_file.sensor.name.lower()
+
     granule_data = granule.open()
     if "latitude" in granule_data:
         pass
@@ -243,8 +251,10 @@ def calculate_obs_properties(
             freqs = []
             offsets = []
             pols = []
+
             for match in CHANNEL_REGEXP.findall(granule_data[f"tbs_s{swath_ind}"].attrs["LongName"]):
                 freq, pol = match
+                freq = freq.replace("/", "")
                 if freq.find("+-") > 0:
                     freq, offs = freq.split("+-")
                     freqs.append(float(freq))
@@ -292,7 +302,7 @@ def calculate_obs_properties(
                 freqs[chan_ind] * np.ones_like(observations[-1]),
                 offsets[chan_ind] * np.ones_like(observations[-1]),
                 pols[chan_ind] * np.ones_like(observations[-1]),
-                beam_width * np.ones_like(observations[-1]),
+                BEAM_WIDTHS[sensor][chan_ind] * np.ones_like(observations[-1]),
                 sensor_alt,
                 zenith,
                 azimuth
@@ -376,8 +386,15 @@ def extract_pretraining_scenes(
 
 
 class InputLoader:
-    def __init__(self, inputs: List[Any]):
+    def __init__(self, inputs: List[Any], radius_of_influence: float = 100e3):
         self.inputs = inputs
+        self.radius_of_influence = radius_of_influence
+
+    def __len__(self) -> int:
+        return len(self.inputs)
+
+    def __getitem__(self, index: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]:
+        return self.load_data(index)
 
     def load_data(self, ind: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]:
@@ -385,10 +402,12 @@ def load_data(self, ind: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]
         target_granule = sorted(list(target_granules))[0]
 
         input_data = run_preprocessor(input_granule)
-        input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=radius_of_influence)
-        target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=radius_of_influence)
+        input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=self.radius_of_influence)
+        target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=self.radius_of_influence)
 
         training_data = xr.Dataset({
+            "latitude": input_data.latitude,
+            "longitude": input_data.longitude,
             "input_observations": input_obs.observations.rename(channels="input_channels"),
             "input_meta_data": input_obs.meta_data.rename(channels="input_channels"),
             "target_observations": target_obs.observations.rename(channels="target_channels"),
@@ -396,13 +415,15 @@ def load_data(self, ind: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]
         })
         tbs = training_data.input_observations.data
         tbs[tbs < 0] = np.nan
+        n_seq_in = tbs.shape[0]
         tbs = training_data.target_observations.data
         tbs[tbs < 0] = np.nan
+        n_seq_out = tbs.shape[0]
 
         input_data = {
-            "input_observations": torch.tensor(training_data.input_observations)[None, None],
-            "input_meta": torch.tensor(training_data.input_meta_data)[None, None],
-            "output_meta": torch.tensor(training_data.target_meta_data)[None, None],
+            "input_observations": torch.tensor(training_data.input_observations.data)[None, None],
+            "input_meta": torch.tensor(training_data.input_meta_data.data)[None].transpose(1, 2),
+            "output_meta": torch.tensor(training_data.target_meta_data.data)[None].transpose(1, 2),
         }
 
         filename = "match_" + target_granule.time_range.start.strftime("%Y%m%d%H%M%s") + ".nc"
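A minimal usage sketch for the reworked InputLoader above (not part of the patch): it now exposes __len__/__getitem__ and carries radius_of_influence itself. The `collocations` list and the unpacked names below are assumptions for illustration; only the class, its methods, and the dictionary keys come from the diff.

    # Sketch only: `collocations` stands in for whatever matched-granule records
    # the loader's `inputs` list normally holds in the gprof_nn pipeline.
    from gprof_nn.data.pretraining import InputLoader

    loader = InputLoader(collocations, radius_of_influence=100e3)
    for index in range(len(loader)):           # __len__ added by this change
        # __getitem__ forwards to load_data: (tensor dict, output filename, scene).
        tensors, filename, scene = loader[index]
        obs = tensors["input_observations"]     # negative brightness temperatures already set to NaN
        meta = tensors["input_meta"]            # per-channel meta data (frequency, offset, polarization, ...)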
diff --git a/gprof_nn/data/training_data.py b/gprof_nn/data/training_data.py
index da76631..7ce5405 100644
--- a/gprof_nn/data/training_data.py
+++ b/gprof_nn/data/training_data.py
@@ -1594,17 +1594,23 @@ def __getitem__(self, ind):
 
 
 class SatformerDataset:
+    """
+    Dataset for training a Satformer to produce simulated brightness temperatures.
+    """
     def __init__(
         self,
         path: Path,
         seq_len_in: int = 13,
         seq_len_out: int = 4,
-        validation: bool = False
+        validation: bool = False,
+        channel_dropout: float = 0.1
     ):
         self.input_files = np.array(sorted(list(Path(path).glob("*.nc"))))
+        self.drop_inputs = 1
         self.seq_len_in = seq_len_in
         self.seq_len_out = seq_len_out
         self.validation = validation
+        self.channel_dropout = channel_dropout
         self.init_rng()
 
     def init_rng(self, w_id=0):
@@ -1639,7 +1645,6 @@ def __getitem__(self, ind: int):
         except Exception:
             return self[self.rng.integers(0, len(self))]
 
-
         n_chans_in = data.input_channels.size
         n_chans_out = data.target_channels.size
         chans_in = self.rng.permutation(n_chans_in)
@@ -1647,19 +1652,28 @@
 
         input_observations = data.input_observations.data.astype("float32")
         input_meta = data.input_meta_data.data.astype("float32")
+        dropped_observations = data.target_observations.data.astype("float32")
+        dropped_meta = data.target_meta_data.data.astype("float32")
         target_observations = data.target_observations.data.astype("float32")
         target_meta = data.target_meta_data.data.astype("float32")
 
         obs_in = []
         meta_in = []
+        obs_dropped = []
+        meta_dropped = []
         for input_ind in range(self.seq_len_in):
-            if input_ind < len(chans_in):
-                obs_in.append(torch.tensor(input_observations[[chans_in[input_ind]]]))
-                meta_in.append(torch.tensor(input_meta[chans_in[input_ind]]))
+            if input_ind < self.drop_inputs:
+                obs_dropped.append(torch.tensor(input_observations[[chans_in[input_ind]]]))
+                meta_dropped.append(torch.tensor(input_meta[chans_in[input_ind]]))
             else:
-                obs_in.append(torch.nan * torch.zeros_like(obs_in[-1]))
-                meta_in.append(torch.nan * torch.zeros_like(meta_in[-1]))
+                rand = self.rng.random()
+                if (rand > self.channel_dropout) and input_ind < len(chans_in):
+                    obs_in.append(torch.tensor(input_observations[[chans_in[input_ind]]]))
+                    meta_in.append(torch.tensor(input_meta[chans_in[input_ind]]))
+                else:
+                    obs_in.append(torch.nan * torch.zeros_like(torch.tensor(input_observations[:1])))
+                    meta_in.append(torch.nan * torch.zeros_like(torch.tensor(input_meta[0])))
 
         obs_out = []
         meta_out = []
@@ -1672,14 +1686,18 @@
                 meta_out.append(torch.nan * torch.zeros_like(meta_out[-1]))
 
         inpt = {
-            "input_observations": torch.stack(obs_in, 1),
-            "input_meta": torch.stack(meta_in, 1),
-            "output_meta": torch.stack(meta_out, 1),
+            "observations": torch.stack(obs_in, 1),
+            "input_observation_props": torch.stack(meta_in, 1),
+            "output_observation_props": torch.stack(meta_out, 1),
+            "dropped_observation_props": torch.stack(meta_dropped, 1),
         }
-        mask = torch.isnan(inpt["input_observations"]).all(0).all(-1).all(-1)
-        inpt["mask"] = mask
-        target = torch.stack(obs_out, 1)
+        mask = torch.isnan(inpt["observations"]).all(0).all(-1).all(-1)
+        inpt["input_observation_mask"] = mask
+        target = {
+            "dropped_observations": obs_dropped,
+            "output_observations": obs_out,
+        }
 
         data.close()
-        return inpt, obs_out
+        return inpt, target
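A similarly hedged sketch of what the updated SatformerDataset yields per sample: the first `drop_inputs` input channels are withheld and returned as targets, each remaining input channel is independently replaced by a NaN placeholder with probability `channel_dropout`, and both the held-out and the output-swath channels come back in the target dict. The path and variable names below are illustrative; only the class, its arguments, and the dictionary keys come from the diff.

    # Sketch only: "training_scenes/" is an assumed directory of *.nc training files.
    from gprof_nn.data.training_data import SatformerDataset

    dataset = SatformerDataset("training_scenes/", channel_dropout=0.1)
    inpt, target = dataset[0]

    obs = inpt["observations"]                   # input channels stacked along dim 1; dropped ones are NaN
    mask = inpt["input_observation_mask"]        # True where an input channel is entirely NaN/missing
    held_out = target["dropped_observations"]    # list with `drop_inputs` withheld input channels
    to_simulate = target["output_observations"]  # list of output-swath channels the model should predict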