Skip to content

Commit

Permalink
Small fixes to the callback plots
Browse files Browse the repository at this point in the history
  • Loading branch information
OpheliaMiralles committed Nov 21, 2024
1 parent 0e6c00a commit a414f48
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,25 +311,53 @@ def plot_histogram(
yt_xt = yt - xt
yp_xt = yp - xt
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt))
bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt))
hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, range=[bin_min, bin_max])
bin_min = min(np.nanpercentile(yt_xt, 0.05), np.nanpercentile(yp_xt, 0.05))
bin_max = max(np.nanpercentile(yt_xt, 0.95), np.percentile(yp_xt, 0.95))
hist_yt, bins_yt = np.histogram(
yt_xt[~np.isnan(yt_xt)],
bins=100,
density=True,
range=[bin_min, bin_max],
)
hist_yp, bins_yp = np.histogram(
yp_xt[~np.isnan(yp_xt)],
bins=100,
density=True,
range=[bin_min, bin_max],
)
else:
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt), np.nanmin(yp))
bin_max = max(np.nanmax(yt), np.nanmax(yp))
hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, range=[bin_min, bin_max])
bin_min = min(np.nanpercentile(yt, 0.05), np.nanpercentile(yp, 0.05))
bin_max = max(np.nanpercentile(yt, 0.95), np.nanpercentile(yp, 0.95))
hist_yt, bins_yt = np.histogram(
yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max]
)
hist_yp, bins_yp = np.histogram(
yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max]
)

# Visualization trick for tp
if variable_name in precip_and_related_fields:
# in-place multiplication does not work here because variables are different numpy types
hist_yt = hist_yt * bins_yt[:-1]
hist_yp = hist_yp * bins_yp[:-1]
# Plot the modified histogram
ax[plot_idx].bar(bins_yt[:-1], hist_yt, width=np.diff(bins_yt), color="blue", alpha=0.7, label="Truth (data)")
ax[plot_idx].bar(bins_yp[:-1], hist_yp, width=np.diff(bins_yp), color="red", alpha=0.7, label="Predicted")
ax[plot_idx].bar(
bins_yt[:-1],
hist_yt,
width=np.diff(bins_yt),
color="blue",
alpha=0.7,
label="Truth (data)",
)
ax[plot_idx].bar(
bins_yp[:-1],
hist_yp,
width=np.diff(bins_yp),
color="red",
alpha=0.7,
label="Predicted",
)

ax[plot_idx].set_title(variable_name)
ax[plot_idx].set_xlabel(variable_name)
Expand Down

0 comments on commit a414f48

Please sign in to comment.