Skip to content

Commit

Permalink
Added satformer training dataset.
Browse files · Browse the repository at this point in the history
  • Loading branch information
simonpf committed Oct 2, 2024
1 parent 47d1a6b commit 67ec68e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 24 deletions.
41 changes: 31 additions & 10 deletions gprof_nn/data/pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@
}


CHANNEL_REGEXP = re.compile("([\d\.\+-]*)\s*GHz\s*(\w*)-Pol")
BEAM_WIDTHS = {
"gmi": [1.75, 1.75, 1.0, 1.0, 0.9, 0.9, 0.9, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
"atms": [5.2, 5.2, 2.2, 1.1, 1.1, 1.1, 1.1, 1.1],
"amsr2": [1.2, 1.2, 0.65, 0.65, 0.75, 0.75, 0.35, 0.35, 0.15, 0.15, 0.15, 0.15],
}


CHANNEL_REGEXP = re.compile("([\d\.\s\+\/-]*)\s*GHz\s*(\w*)-Pol")


SEM_A = 6_378_137.0
Expand Down Expand Up @@ -212,7 +219,6 @@ def calculate_obs_properties(
preprocessor_data: xr.Dataset,
granule: Granule,
radius_of_influence: float = 5e3,
beam_width: float = 1.0
) -> xr.Dataset:
"""
Extract observations and corresponding meta data from granule.
Expand All @@ -224,14 +230,16 @@ def calculate_obs_properties(
"""

lons = preprocessor_data.longitude.data
lats = preprocessor_data.latitude.data
swath = SwathDefinition(lats=lats, lons=lons)

observations = []
meta_data = []

l1c_file = L1CFile(granule.file_record.local_path)
sensor = l1c_file.sensor.name.lower()

granule_data = granule.open()
if "latitude" in granule_data:
pass
Expand All @@ -243,8 +251,10 @@ def calculate_obs_properties(
freqs = []
offsets = []
pols = []

for match in CHANNEL_REGEXP.findall(granule_data[f"tbs_s{swath_ind}"].attrs["LongName"]):
freq, pol = match
freq = freq.replace("/", "")
if freq.find("+-") > 0:
freq, offs = freq.split("+-")
freqs.append(float(freq))
Expand Down Expand Up @@ -292,7 +302,7 @@ def calculate_obs_properties(
freqs[chan_ind] * np.ones_like(observations[-1]),
offsets[chan_ind] * np.ones_like(observations[-1]),
pols[chan_ind] * np.ones_like(observations[-1]),
beam_width * np.ones_like(observations[-1]),
BEAM_WIDTHS[sensor][chan_ind] * np.ones_like(observations[-1]),
sensor_alt,
zenith,
azimuth
Expand Down Expand Up @@ -376,33 +386,44 @@ def extract_pretraining_scenes(


class InputLoader:
def __init__(self, inputs: List[Any]):
def __init__(self, inputs: List[Any], radius_of_influence: float = 100e3):
self.inputs = inputs
self.radius_of_influence = radius_of_influence

def __len__(self) -> int:
    """Return the number of granule pairs available from this loader."""
    return len(self.inputs)

def __getitem__(self, index: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]:
    """
    Load the sample with the given index.

    Thin wrapper around ``load_data``; the return annotation mirrors that
    method's declared return type (the original ``-> int`` did not match).

    Args:
        index: The index of the granule pair to load.

    Return:
        Whatever ``load_data`` returns for ``index``.
    """
    return self.load_data(index)

def load_data(self, ind: int) -> Tuple[Dict[str, torch.Tensor], str, xr.Dataset]:

input_granule, target_granules = self.inputs[ind]
target_granule = sorted(list(target_granules))[0]

input_data = run_preprocessor(input_granule)
input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=radius_of_influence)
target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=radius_of_influence)
input_obs = calculate_obs_properties(input_data, input_granule, radius_of_influence=self.radius_of_influence)
target_obs = calculate_obs_properties(input_data, target_granule, radius_of_influence=self.radius_of_influence)

training_data = xr.Dataset({
"latitude": input_data.latitude,
"longitude": input_data.longitude,
"input_observations": input_obs.observations.rename(channels="input_channels"),
"input_meta_data": input_obs.meta_data.rename(channels="input_channels"),
"target_observations": target_obs.observations.rename(channels="target_channels"),
"target_meta_data": target_obs.meta_data.rename(channels="target_channels"),
})
tbs = training_data.input_observations.data
tbs[tbs < 0] = np.nan
n_seq_in = tbs.shape[0]
tbs = training_data.target_observations.data
tbs[tbs < 0] = np.nan
n_seq_out = tbs.shape[0]

input_data = {
"input_observations": torch.tensor(training_data.input_observations)[None, None],
"input_meta": torch.tensor(training_data.input_meta_data)[None, None],
"output_meta": torch.tensor(training_data.target_meta_data)[None, None],
"input_observations": torch.tensor(training_data.input_observations.data)[None, None],
"input_meta": torch.tensor(training_data.input_meta_data.data)[None].transpose(1, 2),
"output_meta": torch.tensor(training_data.target_meta_data.data)[None].transpose(1, 2),
}

filename = "match_" + target_granule.time_range.start.strftime("%Y%m%d%H%M%s") + ".nc"
Expand Down
46 changes: 32 additions & 14 deletions gprof_nn/data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,17 +1594,23 @@ def __getitem__(self, ind):


class SatformerDataset:
"""
Dataset for training a Satformer to produce simulated brightness temperatures.
"""
def __init__(
self,
path: Path,
seq_len_in: int = 13,
seq_len_out: int = 4,
validation: bool = False
validation: bool = False,
channel_dropout: float = 0.1
):
self.input_files = np.array(sorted(list(Path(path).glob("*.nc"))))
self.drop_inputs = 1
self.seq_len_in = seq_len_in
self.seq_len_out = seq_len_out
self.validation = validation
self.channel_dropout = channel_dropout
self.init_rng()

def init_rng(self, w_id=0):
Expand Down Expand Up @@ -1639,27 +1645,35 @@ def __getitem__(self, ind: int):
except Exception:
return self[self.rng.integers(0, len(self))]


n_chans_in = data.input_channels.size
n_chans_out = data.target_channels.size
chans_in = self.rng.permutation(n_chans_in)
chans_out = self.rng.permutation(n_chans_out)

input_observations = data.input_observations.data.astype("float32")
input_meta = data.input_meta_data.data.astype("float32")
dropped_observations = data.target_observations.data.astype("float32")
dropped_meta = data.target_meta_data.data.astype("float32")
target_observations = data.target_observations.data.astype("float32")
target_meta = data.target_meta_data.data.astype("float32")

obs_in = []
meta_in = []
obs_dropped = []
meta_dropped = []

for input_ind in range(self.seq_len_in):
if input_ind < len(chans_in):
obs_in.append(torch.tensor(input_observations[[chans_in[input_ind]]]))
meta_in.append(torch.tensor(input_meta[chans_in[input_ind]]))
if input_ind < self.drop_inputs:
obs_dropped.append(torch.tensor(input_observations[[chans_in[input_ind]]]))
meta_dropped.append(torch.tensor(input_meta[chans_in[input_ind]]))
else:
obs_in.append(torch.nan * torch.zeros_like(obs_in[-1]))
meta_in.append(torch.nan * torch.zeros_like(meta_in[-1]))
rand = self.rng.random()
if (rand > self.channel_dropout) and input_ind < len(chans_in):
obs_in.append(torch.tensor(input_observations[[chans_in[input_ind]]]))
meta_in.append(torch.tensor(input_meta[chans_in[input_ind]]))
else:
obs_in.append(torch.nan * torch.zeros_like(torch.tensor(input_observations[:1])))
meta_in.append(torch.nan * torch.zeros_like(torch.tensor(input_meta[0])))

obs_out = []
meta_out = []
Expand All @@ -1672,14 +1686,18 @@ def __getitem__(self, ind: int):
meta_out.append(torch.nan * torch.zeros_like(meta_out[-1]))

inpt = {
"input_observations": torch.stack(obs_in, 1),
"input_meta": torch.stack(meta_in, 1),
"output_meta": torch.stack(meta_out, 1),
"observations": torch.stack(obs_in, 1),
"input_observation_props": torch.stack(meta_in, 1),
"output_observation_props": torch.stack(meta_out, 1),
"dropped_observation_props": torch.stack(meta_dropped, 1),
}
mask = torch.isnan(inpt["input_observations"]).all(0).all(-1).all(-1)
inpt["mask"] = mask
target = torch.stack(obs_out, 1)
mask = torch.isnan(inpt["observations"]).all(0).all(-1).all(-1)
inpt["input_observation_mask"] = mask

target = {
"dropped_observations": obs_dropped,
"output_observations": obs_out,
}
data.close()

return inpt, obs_out
return inpt, target

0 comments on commit 67ec68e

Please sign in to comment.