Skip to content

Commit

Permalink
Fix nodeweights
Browse files Browse the repository at this point in the history
  • Loading branch information
OpheliaMiralles committed Dec 11, 2024
1 parent 2df18a7 commit 44f070e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,8 @@ def get_scatter_frame(
vmin=vmin,
vmax=vmax,
)
ax.set_xlim((-np.pi, np.pi))
ax.set_ylim((-np.pi / 2, np.pi / 2))
# ax.set_xlim((-np.pi, np.pi))
# ax.set_ylim((-np.pi / 2, np.pi / 2))
continents.plot_continents(ax)
ax.set_aspect("auto", adjustable=None)
_hide_axes_ticks(ax)
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/training/losses/nodeweights.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
torch.Tensor
area weights of the target nodes
"""
return AreaWeights(norm="unit-max", fill_value=0).compute(graph_data, self.target)
return AreaWeights(flat=True, norm="unit-max").compute(graph_data, self.target)

def weights(self, graph_data: HeteroData) -> torch.Tensor:
"""Returns weight of type self.node_attribute for nodes self.target.
Expand Down

0 comments on commit 44f070e

Please sign in to comment.