Skip to content

Commit

Permalink
Tests GPROFNN1DDataset for AMSR2.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpf committed Mar 28, 2024
1 parent 98d4296 commit 2f2ab0b
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion tests/data/test_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def test_load_ancillary_data_mhs(training_files_1d_mhs_sim):
"training_files_1d_mhs_mrms",
"training_files_1d_mhs_era5"
])
def test_load_training_data_1d_mhs(training_files, request):
def test_gprof_nn_1d_dataset_mhs(training_files, request):

training_files = request.getfixturevalue(training_files)

Expand Down Expand Up @@ -519,6 +519,35 @@ def test_load_targets_amsr2_sim(training_files_1d_amsr2_sim):
assert "surface_precip" in targets
assert np.isfinite(targets["surface_precip"].numpy()).all()

@pytest.mark.parametrize(
"training_files_1d",
[
"training_files_1d_amsr2_sim",
"training_files_1d_amsr2_mrms",
"training_files_1d_amsr2_era5",
])
def test_gprof_nn_1d_dataset_amsr2(training_files_1d, request):
"""
Ensure that the GPROFNN3DDataset correctly loads the GPROF-NN 3D training data.
"""
training_files = request.getfixturevalue(training_files_1d)
training_data = GPROFNN1DDataset(training_files[0].parent)

x, y = training_data[0]
assert "brightness_temperatures" in x
tbs = x["brightness_temperatures"]
assert tbs.ndim == 3
assert (tbs[torch.isfinite(tbs)] > 0).all()
assert "viewing_angles" in x
assert x["viewing_angles"].ndim == 3
assert "ancillary_data" in x
assert x["ancillary_data"].ndim == 3

assert "surface_precip" in y
sp = y["surface_precip"]
assert sp.ndim == 2
assert (sp[torch.isfinite(sp)] >= 0.0).all()


def test_load_training_data_3d_conical_sim(training_files_3d_amsr2_sim):
"""
Expand Down

0 comments on commit 2f2ab0b

Please sign in to comment.