Skip to content

Commit

Permalink
Fixes to callback plots (ecmwf#182)
Browse files Browse the repository at this point in the history
* Lower bound delta lat in power spectrum plot and align input color map for
precip plots
  • Loading branch information
OpheliaMiralles authored Dec 10, 2024
1 parent 2179a59 commit da26cb7
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Keep it human-readable, your future self will thank you!
- Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178)
- Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180)
- Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190)
- Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic).

### Added
- Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
Expand Down
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ callbacks:

- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
# every_n_batches: 100 # Override for batch frequency
# min_delta: 0.01 # Minimum distance between two consecutive points
sample_idx: ${diagnostics.plot.sample_idx}
parameters:
- z_500
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,7 @@ def __init__(
config: OmegaConf,
sample_idx: int,
parameters: list[str],
min_delta: float | None = None,
every_n_batches: int | None = None,
) -> None:
"""Initialise the PlotSpectrum callback.
Expand All @@ -1036,6 +1037,7 @@ def __init__(
super().__init__(config, every_n_batches=every_n_batches)
self.sample_idx = sample_idx
self.parameters = parameters
self.min_delta = min_delta

@rank_zero_only
def _plot(
Expand Down Expand Up @@ -1070,6 +1072,7 @@ def _plot(
data[0, ...].squeeze(),
data[rollout_step + 1, ...].squeeze(),
output_tensor[rollout_step, ...],
min_delta=self.min_delta,
)

self._output_figure(
Expand Down
67 changes: 62 additions & 5 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def plot_power_spectrum(
x: np.ndarray,
y_true: np.ndarray,
y_pred: np.ndarray,
min_delta: float | None = None,
) -> Figure:
"""Plots power spectrum.
Expand All @@ -156,13 +157,16 @@ def plot_power_spectrum(
Expected data of shape (lat*lon, nvar*level)
y_pred : np.ndarray
Predicted data of shape (lat*lon, nvar*level)
min_delta: float, optional
Minimum distance between lat/lon points, if None defaulted to 1km
Returns
-------
Figure
The figure object handle.
"""
min_delta = min_delta or 0.0003
n_plots_x, n_plots_y = len(parameters), 1

figsize = (n_plots_y * 4, n_plots_x * 3)
Expand All @@ -177,9 +181,17 @@ def plot_power_spectrum(
# Calculate delta_lat on the projected grid
delta_lat = abs(np.diff(pc_lat))
non_zero_delta_lat = delta_lat[delta_lat != 0]
min_delta_lat = np.min(abs(non_zero_delta_lat))

if min_delta_lat < min_delta:
LOGGER.warning(
"Min. distance between lat/lon points is < specified minimum distance. Defaulting to min_delta=%s.",
min_delta,
)
min_delta_lat = min_delta

# Define a regular grid for interpolation
n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat))))
n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / min_delta_lat))
n_pix_lon = (n_pix_lat - 1) * 2 + 1 # 2*lmax + 1
regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon)
regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat)
Expand Down Expand Up @@ -313,14 +325,14 @@ def plot_histogram(
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt))
bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt))
hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, range=[bin_min, bin_max])
hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, density=True, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, density=True, range=[bin_min, bin_max])
else:
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt), np.nanmin(yp))
bin_max = max(np.nanmax(yt), np.nanmax(yp))
hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, range=[bin_min, bin_max])
hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max])

# Visualization trick for tp
if variable_name in precip_and_related_fields:
Expand Down Expand Up @@ -623,6 +635,51 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.",
datashader=datashader,
)
elif vname in precip_and_related_fields:
# Create a custom colormap for precipitation
nws_precip_colors = cmap_precip
precip_colormap = ListedColormap(nws_precip_colors)

# Defining the actual precipitation accumulation levels in mm
cummulation_lvls = clevels
norm = BoundaryNorm(cummulation_lvls, len(cummulation_lvls) + 1)

# converting to mm from m
input_ *= 1000.0
truth *= 1000.0
pred *= 1000.0
single_plot(
fig,
ax[0],
lon=lon,
lat=lat,
data=input_,
cmap=precip_colormap,
title=f"{vname} input",
datashader=datashader,
)
single_plot(
fig,
ax[4],
lon=lon,
lat=lat,
data=pred - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} increment [pred - input]",
datashader=datashader,
)
single_plot(
fig,
ax[5],
lon=lon,
lat=lat,
data=truth - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} persist err",
datashader=datashader,
)
else:
single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader)
single_plot(
Expand Down

0 comments on commit da26cb7

Please sign in to comment.