Commit: Refactoring training datasets
simonpf committed Mar 18, 2024
1 parent afb5da2 commit 9152bed
Showing 49 changed files with 991 additions and 7,483 deletions.
2 changes: 1 addition & 1 deletion gprof_nn/config.py
@@ -78,7 +78,7 @@ class DataConfig(ConfigBase):
"""
era5_path : Path = Path("/qdata2/archive/ERA5")
model_path : Path = Path(user_data_dir("gprof_nn", "gprof_nn")) / "models"
mrms_path : Path = Path("/pdata4/mrms/")
mrms_path : Path = Path("/pdata4/veljko/")

def print(self):
txt = "[data]\n"
33 changes: 25 additions & 8 deletions gprof_nn/data/era5.py
@@ -19,6 +19,7 @@

from gprof_nn.config import CONFIG
from gprof_nn import sensors
from gprof_nn.definitions import DATA_SPLIT
from gprof_nn.data.l1c import L1CFile
from gprof_nn.data.preprocessor import run_preprocessor
from gprof_nn.logging import get_console, log_messages
@@ -174,25 +175,23 @@ def process_l1c_file(

# Drop unneeded variables.
drop = ["sunglint_angle", "quality_flag", "wet_bulb_temperature", "lapse_rate"]
if not isinstance(sensor, sensors.CrossTrackScanner):
drop.append("earth_incidence_angle")
data_pp = data_pp.drop_vars(drop)

start_time = data_pp["scan_time"].data[0]
end_time = data_pp["scan_time"].data[-1]
era5_data = load_era5_data(start_time, end_time)
add_era5_precip(data_pp, era5_data)

data_pp.attrs["source"] = 2
data_pp.attrs["source"] = "era5"
if output_path_1d is not None:
write_training_samples_1d(
output_path_1d,
"mrms",
data_pp,
)
if output_path_3d is not None:
n_pixels = data_pp.pixels.size
n_scans = max(n_pixels, 128)
n_pixels = 64
n_scans = 128
write_training_samples_3d(
output_path_3d,
"mrms",
@@ -212,6 +211,7 @@ def process_l1c_files(
end_time: np.datetime64,
output_path_1d: Optional[Path] = None,
output_path_3d: Optional[Path] = None,
split: Optional[str] = None,
n_processes: int = 4,
log_queue: Optional[Queue] = None
):
@@ -226,6 +226,8 @@
the training samples for the GPROF-NN 1D retrieval.
output_path_3d: Path pointing to the folder to which to write
the training samples for the GPROF-NN 3D retrieval.
split: An optional string specifying one of the three data splits
('training', 'validation', 'test') to which to restrict the extraction.
n_processes: The number of processes to use for parallel
processing.
log_queue: Queue to use for logging from sub-processes.
@@ -241,10 +243,25 @@

LOGGER.info("Looking for files in %s.", l1c_path)
while time < end_time:
l1c_files += L1CFile.find_files(time, l1c_path)

l1c_files_day = L1CFile.find_files(time, l1c_path, sensor=sensor)
# Restrict the day's files to the requested data split.
if split is not None:
    days = DATA_SPLIT[split]
    l1c_files_split = []
    for l1c_file in l1c_files_day:
        start_time = L1CFile(l1c_file).start_time
        day_of_month = int(
            (start_time - start_time.astype("datetime64[M]"))
            .astype("timedelta64[D]").astype("int64")
        )
        if day_of_month + 1 in days:
            l1c_files_split.append(l1c_file)
    l1c_files_day = l1c_files_split

l1c_files += l1c_files_day
time += np.timedelta64(24 * 60 * 60, "s")

LOGGER.info("Found %s L1C fiels to process", len(l1c_files))
LOGGER.info("Found %s L1C files to process", len(l1c_files))


pool = ProcessPoolExecutor(max_workers=n_processes)
@@ -263,7 +280,7 @@ def process_l1c_files(

with Progress(console=get_console()) as progress:
pbar = progress.add_task(
"Extracting pretraining data:",
"Extracting ERA5 collocations:",
total=len(tasks)
)
for task in as_completed(tasks):
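The split filtering added above keys off the day of the month: casting a timestamp to `datetime64[M]` truncates it to the first of its month, so the day difference is zero-based and `day_of_month + 1` recovers the calendar day, which is then looked up in `DATA_SPLIT`. A minimal, self-contained sketch of that logic (the `DATA_SPLIT` table below is a hypothetical stand-in for `gprof_nn.definitions.DATA_SPLIT`, whose actual day assignments are not shown in this diff):

```python
import numpy as np

# Hypothetical stand-in for gprof_nn.definitions.DATA_SPLIT: 1-based days
# of the month assigned to each split. The real assignments may differ.
DATA_SPLIT = {
    "training": set(range(1, 21)),
    "validation": set(range(21, 26)),
    "test": set(range(26, 32)),
}

def day_of_month(time: np.datetime64) -> int:
    """Return the 1-based calendar day of a datetime64 value."""
    # Casting to month precision truncates to the first of the month,
    # so the difference in days is zero-based.
    days = (time - time.astype("datetime64[M]")).astype("timedelta64[D]")
    return int(days.astype("int64")) + 1

def in_split(start_time: np.datetime64, split: str) -> bool:
    """Check whether a granule start time falls into a given data split."""
    return day_of_month(start_time) in DATA_SPLIT[split]

assert day_of_month(np.datetime64("2024-03-18T09:30")) == 18
```

Splitting on days of the month keeps training, validation, and test data disjoint in time while still sampling every month and season.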
31 changes: 26 additions & 5 deletions gprof_nn/data/l1c.py
@@ -168,6 +168,8 @@ def find_files(cls, date, path, roi=None, sensor=sensors.GMI):
month = date.month
day = date.day
data_path = Path(path) / f"{year:02}{month:02}" / f"{year:02}{month:02}{day:02}"
print("DATA PATH :: ", data_path)
print("PATTERN :: ", sensor.l1c_file_prefix + f"*{date.year:04}{month:02}{day:02}*{sensor.l1c_version}.HDF5")
files = list(
data_path.glob(
sensor.l1c_file_prefix + f"*{date.year:04}{month:02}{day:02}*{sensor.l1c_version}.HDF5"
@@ -478,23 +480,45 @@ def to_xarray_dataset(self, roi=None):

# Handle case that observations are split up.
tbs = []
eia = []
tbs.append(input[f"{swath}/Tc"][:][indices])
eia_s = input[f"{swath}/incidenceAngle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)
if "S2" in input.keys():
tbs.append(input["S2/Tc"][:][indices])
eia_s = input[f"S2/incidenceAngle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)
if "S3" in input.keys():
tbs.append(input["S3/Tc"][:][indices])
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)
if "S4" in input.keys():
tbs.append(input["S4/Tc"][:][indices])
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)
if "S5" in input.keys():
tbs_s = input["S5/Tc"][:][indices]
eia_s = input[f"s2/incidenceangle"][:][indices]
if tbs_s.shape[-2] > tbs[-1].shape[-2]:
tbs_s = tbs_s[..., ::2, :]
eia_s = eia_s[..., ::2]
tbs.append(tbs_s)
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)
if "S6" in input.keys():
tbs_s = input["S6/Tc"][:][indices]
eia_s = input[f"s2/incidenceangle"][:][indices]
if tbs_s.shape[-2] > tbs[-1].shape[-2]:
tbs_s = tbs_s[..., ::2, :]
eia_s = eia_s[..., ::2]
tbs.append(tbs_s)
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)

n_pixels = max([array.shape[1] for array in tbs])
tbs_r = []
@@ -546,11 +570,8 @@ def to_xarray_dataset(self, roi=None):
"scan_time": (dims[:1], times),
}

if "incidenceAngle" in input[f"{swath}"].keys():
data["incidence_angle"] = (
dims,
input[f"{swath}/incidenceAngle"][indices, :, 0],
)
eia = np.concatenate(eia, axis=-1)
data["incidence_angle"] = (dims + ("channels",), eia)

if "SCorientation" in input[f"{swath}/SCstatus"]:
data["sensor_orientation"] = (
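The reworked `to_xarray_dataset` replaces the old per-swath `incidence_angle` variable with a per-channel one: each swath's incidence angle is broadcast to the shape of that swath's brightness temperatures, high-resolution swaths are decimated to the common pixel grid, and the results are concatenated along the channel dimension. A self-contained sketch of the broadcast-and-concatenate step, using made-up GMI-like shapes (128 scans, 221 pixels, 9 + 4 channels) instead of real L1C data:

```python
import numpy as np

rng = np.random.default_rng(42)
tbs_s1 = 150.0 + 150.0 * rng.random((128, 221, 9))   # S1 brightness temps
eia_s1 = np.full((128, 221, 1), 52.8)                 # S1 incidence angle
tbs_s2 = 150.0 + 150.0 * rng.random((128, 221, 4))   # S2 brightness temps
eia_s2 = np.full((128, 221, 1), 49.2)                 # S2 incidence angle

tbs, eia = [], []
for tbs_s, eia_s in [(tbs_s1, eia_s1), (tbs_s2, eia_s2)]:
    tbs.append(tbs_s)
    # Broadcast the per-pixel angle across the swath's channels so that
    # every channel carries its own viewing geometry.
    eia.append(np.broadcast_to(eia_s, tbs_s.shape))

tbs = np.concatenate(tbs, axis=-1)
eia = np.concatenate(eia, axis=-1)
assert tbs.shape == eia.shape == (128, 221, 13)
```

Carrying one angle per channel means downstream code no longer has to special-case sensors whose swaths are observed at different incidence angles.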
37 changes: 30 additions & 7 deletions gprof_nn/data/mrms.py
@@ -15,15 +15,15 @@
import click
import numpy as np
import xarray as xr
import pandas as pd
from pyresample import geometry, kd_tree
from pykdtree.kdtree import KDTree
from rich.progress import Progress
from scipy.signal import convolve

from gprof_nn import sensors
from gprof_nn.logging import get_console, log_messages
from gprof_nn.coordinates import latlon_to_ecef
from gprof_nn.definitions import DATA_SPLIT
from gprof_nn.logging import get_console, log_messages
from gprof_nn.data.validation import unify_grid, calculate_angles
from gprof_nn.data.preprocessor import run_preprocessor
from gprof_nn.data.l1c import L1CFile
@@ -269,7 +269,7 @@ def extract_collocations(

# Match targets
match_file.match_targets(data_pp)
data_pp.attrs["source"] = 1
data_pp.attrs["source"] = "mrms"

if output_path_1d is not None:
write_training_samples_1d(
@@ -279,8 +279,8 @@
)

if output_path_3d is not None:
n_pixels = data_pp.pixels.size
n_scans = max(n_pixels, 128)
n_pixels = 64
n_scans = 128
write_training_samples_3d(
output_path_3d,
"mrms",
@@ -299,7 +299,8 @@ def process_match_file(
l1c_path: Path,
output_path_1d: Optional[Path] = None,
output_path_3d: Optional[Path] = None,
n_processes: int = 4
n_processes: int = 4,
split: Optional[str] = None
):
"""
Process a single MRMS match-up file.
@@ -311,21 +312,40 @@
l1c_path: Path object pointing to the folder containing the L1C
files to collocate with the match ups.
output_path_1d: Path pointing to the folder to which to write
the GPROF-NN 1D training data.
output_path_3d: Path pointing to the folder to which to write
the GPROF-NN 3D training data.
n_processes: The number of processes to use for the data
extraction.
split: An optional string ('training', 'validation', 'test') specifying
which split of the dataset to extract.
"""
match_file = Path(match_file)
year_month = match_file.name[:4]
l1c_path = Path(l1c_path)
l1c_files = (l1c_path / year_month).glob(
f"**/{sensor.l1c_file_prefix}*.HDF5"
)
l1c_files = sorted(list(l1c_files))

if split is not None:
    l1c_files_split = []
    days = DATA_SPLIT[split]
    for l1c_file in l1c_files:
        start_time = L1CFile(l1c_file).start_time
        day_of_month = int(
            (start_time - start_time.astype("datetime64[M]"))
            .astype("timedelta64[D]").astype("int64")
        )
        if day_of_month + 1 in days:
            l1c_files_split.append(l1c_file)
    l1c_files = l1c_files_split

LOGGER.info(
f"Found {len(l1c_files)} L1C files matching MRMS match-up file "
f"{match_file}."
)


pool = ProcessPoolExecutor(max_workers=n_processes)
tasks = []
for l1c_file in l1c_files:
@@ -369,6 +389,7 @@ def process_match_files(
l1c_path: Path,
output_path_1d: Path,
output_path_3d: Path,
split: Optional[str] = None,
n_processes: int = 4
):
"""
@@ -391,6 +412,7 @@
l1c_path,
output_path_1d,
output_path_3d,
split=split,
n_processes=n_processes
)

Expand Down Expand Up @@ -455,6 +477,7 @@ def cli(
l1c_path,
output_path_1d,
output_path_3d,
split=split,
n_processes=n_processes
)

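Both the ERA5 and MRMS extraction now write GPROF-NN 3D training scenes with a fixed size of 128 scans by 64 pixels instead of deriving the scene size from the swath width. The diff does not show how scenes are cut to that size; the helper below is only a hedged sketch of what such a fixed-size crop could look like, assuming scenes are xarray datasets with `scans` and `pixels` dimensions (`crop_scene` is hypothetical, not part of this commit):

```python
from typing import Optional

import numpy as np
import xarray as xr

def crop_scene(
    scene: xr.Dataset,
    n_scans: int = 128,
    n_pixels: int = 64,
    rng: Optional[np.random.Generator] = None,
) -> xr.Dataset:
    """Crop a random, fixed-size training scene out of a collocation scene."""
    if rng is None:
        rng = np.random.default_rng()
    # Upper bounds are inclusive offsets; rng.integers excludes 'high'.
    i = int(rng.integers(0, max(scene.sizes["scans"] - n_scans, 0) + 1))
    j = int(rng.integers(0, max(scene.sizes["pixels"] - n_pixels, 0) + 1))
    return scene.isel(
        scans=slice(i, i + n_scans),
        pixels=slice(j, j + n_pixels),
    )
```

Fixed-size scenes can be stacked directly into training batches without padding or dynamic shapes.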
4 changes: 2 additions & 2 deletions gprof_nn/data/preprocessor.py
@@ -726,12 +726,12 @@ def has_preprocessor():
"AMSR2": "gprof2023pp_AMSR2_L1C",
"AMSRE": "gprof2021pp_AMSRE_L1C",
"ATMS": "gprof2021pp_ATMS_L1C",
("GMI", "MHS"): "gprof2021pp_GMI_MHS_L1C",
("GMI", "MHS"): "gprof2023pp_GMI_L1C",
("GMI", "TMIPR"): "gprof2021pp_GMI_TMI_L1C",
("GMI", "TMIPO"): "gprof2021pp_GMI_TMI_L1C",
("GMI", "SSMI"): "gprof2021pp_GMI_SSMI_L1C",
("GMI", "SSMIS"): "gprof2021pp_GMI_SSMIS_L1C",
("GMI", "AMSR2"): "gprof2021pp_GMI_AMSR2_L1C",
("GMI", "AMSR2"): "gprof2023pp_GMI_L1C",
("GMI", "AMSRE"): "gprof2021pp_GMI_AMSRE_L1C",
("GMI", "ATMS"): "gprof2021pp_GMI_ATMS_L1C",
}
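The table above maps either a sensor name or a (training sensor, target sensor) pair to a preprocessor executable, and this commit points the GMI/MHS and GMI/AMSR2 pairs at the 2023 GMI preprocessor. A minimal sketch of how such a table can be queried (`get_preprocessor` and the trimmed-down table are illustrative, not the module's actual API):

```python
from typing import Dict, Optional, Tuple, Union

# Illustrative subset of the executable table from the diff above.
PREPROCESSOR_EXECUTABLES: Dict[Union[str, Tuple[str, str]], str] = {
    "AMSR2": "gprof2023pp_AMSR2_L1C",
    ("GMI", "MHS"): "gprof2023pp_GMI_L1C",
    ("GMI", "AMSR2"): "gprof2023pp_GMI_L1C",
}

def get_preprocessor(sensor: str, target: Optional[str] = None) -> str:
    """Look up the preprocessor executable for a sensor or sensor pair."""
    key = sensor if target is None else (sensor, target)
    try:
        return PREPROCESSOR_EXECUTABLES[key]
    except KeyError:
        raise ValueError(
            f"No preprocessor executable registered for {key!r}."
        ) from None

assert get_preprocessor("GMI", "MHS") == "gprof2023pp_GMI_L1C"
```

Raising a ValueError at lookup time keeps an unregistered sensor pair from surfacing later as an opaque KeyError deep inside a preprocessing run.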
1 change: 0 additions & 1 deletion gprof_nn/data/retrieval.py
@@ -11,7 +11,6 @@
from pathlib import Path

import numpy as np
from quantnn.normalizer import Normalizer
import xarray

from gprof_nn.definitions import MISSING
