Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hatchmap and scattermap fixes #195

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ New features and enhancements
* ``fg.taylordiagram`` can now accept datasets with many dimensions (not only `taylor_params`), provided that they all share the same `ref_std` (e.g. normalized taylor diagrams) (:pull:`214`).
* A new optional way to organize points in a ``fg.taylordiagram`` with `colors_key`, `markers_key` : DataArrays with a common dimension value or a common attribute are grouped with the same color/marker (:pull:`214`).
* Heatmap (``fg.matplotlib.heatmap``) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. (:issue:`208`, :pull:`219`).
* No-legend option in ``hatchmap``; use ``edgecolor`` and ``edgecolors`` as aliases (:pull:`195`)

Breaking changes
^^^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/figanos_multiplots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@
"im = fg.hatchmap({'sup_305k': sup_305k, 'inf_300k': inf_300k},\n",
" plot_kw={\n",
" 'sup_305k': {\n",
" 'hatches': '*',\n",
" 'hatches': ['////'], # hatches must be passed as a list\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use can instead of must

" 'col': 'time',\n",
" \"x\": \"lon\",\n",
" \"y\": \"lat\"\n",
Expand All @@ -199,7 +199,7 @@
" enumerate_subplots=True, \n",
" )\n",
"\n",
"im.fig.suptitle(\"Multiple hatchmaps\", y=1.08)\n"
"im.fig.suptitle(\"Multiple hatchmaps\", y=1.08)"
]
},
{
Expand Down
129 changes: 71 additions & 58 deletions src/figanos/matplotlib/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,17 +1617,15 @@ def scattermap(
if "row" not in plot_kw and "col" not in plot_kw:
use_attrs.setdefault("title", "description")

plot_kw_pop = copy.deepcopy(plot_kw) # copy plot_kw to modify and pop info in it

# extract plot_kw from dict if needed
if isinstance(data, dict) and plot_kw and list(data.keys())[0] in plot_kw.keys():
plot_kw_pop = plot_kw_pop[list(data.keys())[0]]
plot_kw = plot_kw[list(data.keys())[0]]

# figanos does not use xr.plot.scatter default markersize
if "markersize" in plot_kw.keys():
if not sizes:
sizes = plot_kw["markersize"]
plot_kw_pop.pop("markersize")
plot_kw.pop("markersize")

# if data is dict, extract
if isinstance(data, dict):
Expand Down Expand Up @@ -1667,13 +1665,13 @@ def scattermap(
elif ax is not None and ("col" in plot_kw or "row" in plot_kw):
raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
elif ax is None:
plot_kw_pop = {"subplot_kws": {"projection": projection}} | plot_kw_pop
plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw
cfig_kw = fig_kw.copy()
if "figsize" in fig_kw: # add figsize to plot_kw for facetgrid
plot_kw_pop.setdefault("figsize", fig_kw["figsize"])
plot_kw.setdefault("figsize", fig_kw["figsize"])
cfig_kw.pop("figsize")
if len(cfig_kw) >= 1:
plot_kw_pop = {"subplot_kws": {"projection": projection}} | plot_kw_pop
plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw
warnings.warn(
"Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid."
)
Expand All @@ -1693,9 +1691,9 @@ def scattermap(
cbar_label = get_attributes(use_attrs["cbar_label"], data)

if "add_colorbar" not in plot_kw or plot_kw["add_colorbar"] is not False:
plot_kw_pop.setdefault("cbar_kwargs", {})
plot_kw_pop["cbar_kwargs"].setdefault("label", wrap_text(cbar_label))
plot_kw_pop["cbar_kwargs"].setdefault("pad", 0.015)
plot_kw.setdefault("cbar_kwargs", {})
plot_kw["cbar_kwargs"].setdefault("label", wrap_text(cbar_label))
plot_kw["cbar_kwargs"].setdefault("pad", 0.015)

# colormap
if isinstance(cmap, str):
Expand Down Expand Up @@ -1748,12 +1746,15 @@ def scattermap(
target_range=size_range,
data_range=None,
)
plot_kw_pop.setdefault("add_legend", False)
plot_kw.setdefault("add_legend", False)
if ax:
plot_kw_pop.setdefault("s", pt_sizes)
plot_kw.setdefault("s", pt_sizes)
else:
plot_kw_pop.setdefault("s", pt_sizes[0])
plot_kw.setdefault("s", pt_sizes[0])

# norm
plot_kw.setdefault("vmin", np.nanmin(plot_data.values[mask]))
plot_kw.setdefault("vmax", np.nanmax(plot_data.values[mask]))
if levels is not None:
if isinstance(levels, Iterable):
lin = levels
Expand All @@ -1766,7 +1767,7 @@ def scattermap(
divergent=divergent,
linspace_out=True,
)
plot_kw_pop.setdefault("levels", lin)
plot_kw.setdefault("levels", lin)

elif (divergent is not False) and ("levels" not in plot_kw):
norm = custom_cmap_norm(
Expand All @@ -1776,42 +1777,51 @@ def scattermap(
levels=levels,
divergent=divergent,
)
plot_kw_pop.setdefault("norm", norm)
plot_kw.setdefault("norm", norm)

# matplotlib.pyplot.scatter treats "edgecolor" and "edgecolors" as aliases so we accept "edgecolor" and convert it
if "edgecolor" in plot_kw and "edgecolors" not in plot_kw:
plot_kw["edgecolors"] = plot_kw["edgecolor"]
plot_kw.pop("edgecolor")

# set defaults and
plot_kw_pop = {
# set defaults and create copy without vmin, vmax (conflicts with norm)
plot_kw = {
"cmap": cmap,
"transform": transform,
"zorder": 8,
"marker": "o",
} | plot_kw_pop
"edgecolors": "none",
} | plot_kw

# chek if edgecolors in plot_kw and match len of plot_data
if "edgecolors" in plot_kw:
if matplotlib.colors.is_color_like(plot_kw["edgecolors"]):
plot_kw_pop["edgecolors"] = np.repeat(
plot_kw["edgecolors"] = np.repeat(
plot_kw["edgecolors"], len(plot_data.where(mask).values)
)
elif len(plot_kw["edgecolors"]) != len(plot_data.values):
plot_kw_pop["edgecolors"] = np.repeat(
plot_kw["edgecolors"] = np.repeat(
plot_kw["edgecolors"][0], len(plot_data.where(mask).values)
)
warnings.warn(
"Length of edgecolors does not match length of data. Only first edgecolor is used for plotting."
)
else:
if isinstance(plot_kw["edgecolors"], list):
plot_kw_pop["edgecolors"] = np.array(plot_kw["edgecolors"])
plot_kw_pop["edgecolors"] = plot_kw_pop["edgecolors"][mask]
plot_kw["edgecolors"] = np.array(plot_kw["edgecolors"])
plot_kw["edgecolors"] = plot_kw["edgecolors"][mask]
else:
plot_kw_pop.setdefault("edgecolor", "none")
plot_kw.setdefault("edgecolors", "none")

for key in ["vmin", "vmax"]:
plot_kw.pop(key)
# plot
plot_kw_pop = {"x": "lon", "y": "lat", "hue": plot_data.name} | plot_kw_pop
plot_kw = {"x": "lon", "y": "lat", "hue": plot_data.name} | plot_kw
if ax:
plot_kw_pop.setdefault("ax", ax)
v = plot_data.where(mask).to_dataset()
im = v.plot.scatter(**plot_kw_pop)
plot_kw.setdefault("ax", ax)

plot_data_masked = plot_data.where(mask).to_dataset()
im = plot_data_masked.plot.scatter(**plot_kw)

# add features
if ax:
Expand Down Expand Up @@ -1874,7 +1884,7 @@ def scattermap(
np.resize(sdata.values[mask], (sdata.values[mask].size, 1)),
np.resize(pt_sizes[mask], (pt_sizes[mask].size, 1)),
max_entries=6,
marker=plot_kw_pop["marker"],
marker=plot_kw["marker"],
)
# legend spacing
if size_range[1] > 200:
Expand Down Expand Up @@ -2281,7 +2291,7 @@ def hatchmap(
features: list[str] | dict[str, dict[str, Any]] | None = None,
geometries_kw: dict[str, Any] | None = None,
levels: int | None = None,
legend_kw: dict[str, Any] | None = None,
legend_kw: dict[str, Any] | bool = True,
show_time: bool | str | int | tuple[float, float] = False,
frame: bool = False,
enumerate_subplots: bool = False,
Expand Down Expand Up @@ -2313,8 +2323,8 @@ def hatchmap(
cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
geometries_kw : dict, optional
Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
legend_kw : dict, optional
Arguments to pass to `ax.legend()`.
legend_kw : dict or boolean, optional
Arguments to pass to `ax.legend()`. No legend is added if legend_kw == False.
show_time : bool, tuple, string or int.
If True, show time (as date) at the bottom right of the figure.
Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location
Expand Down Expand Up @@ -2377,24 +2387,23 @@ def hatchmap(

dattrs = None
plot_data = {}
dc = plot_kw.copy()

# convert data to dict (if not one)
if not isinstance(data, dict):
if isinstance(data, xr.DataArray):
plot_data = {data.name: data}
if list(data.keys())[0] not in plot_kw.keys():
plot_kw = {list(plot_data.keys())[0]: dc}
if data.name not in plot_kw.keys():
plot_kw = {data.name: plot_kw}
elif isinstance(data, xr.Dataset):
dattrs = data
plot_data = {var: data[var] for var in data.data_vars}
for v in plot_data.keys():
if v not in plot_kw.keys():
plot_kw[v] = dc
plot_kw[v] = plot_kw
else:
for k, v in data.items():
if k not in plot_kw.keys():
plot_kw[k] = dc
plot_kw[k] = plot_kw
if isinstance(v, xr.Dataset):
dattrs = k
plot_data[k] = v[list(v.data_vars)[0]]
Expand All @@ -2416,28 +2425,25 @@ def hatchmap(
if transform and (
"xlim" in list(plot_kw.values())[0] and "ylim" in list(plot_kw.values())[0]
):
extend = [
extent = [
list(plot_kw.values())[0]["xlim"][0],
list(plot_kw.values())[0]["xlim"][1],
list(plot_kw.values())[0]["ylim"][0],
list(plot_kw.values())[0]["ylim"][1],
]
{v.pop("xlim") for v in plot_kw.values()}
{v.pop("ylim") for v in plot_kw.values()}
[v.pop(lim) for lim in ["xlim", "ylim"] for v in plot_kw.values() if lim in v]

elif transform and (
"xlim" in list(plot_kw.values())[0] or "ylim" in list(plot_kw.values())[0]
):
extend = None
extent = None
warnings.warn(
"Requires both xlim and ylim with 'transform'. Xlim or ylim was dropped"
)
if "xlim" in list(plot_kw.values())[0].keys():
{v.pop("xlim") for v in plot_kw.values()}
if "ylim" in list(plot_kw.values())[0].keys():
{v.pop("ylim") for v in plot_kw.values()}
[v.pop(lim) for lim in ["xlim", "ylim"] for v in plot_kw.values() if lim in v]

else:
extend = None
extent = None

# setup fig, ax
if ax is None and (
Expand All @@ -2451,11 +2457,11 @@ def hatchmap(
):
raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
elif ax is None:
{
[
v.setdefault("subplot_kws", {}).setdefault("projection", projection)
for v in plot_kw.values()
}
cfig_kw = fig_kw.copy()
]
cfig_kw = copy.deepcopy(fig_kw)
if "figsize" in fig_kw: # add figsize to plot_kw for facetgrid
plot_kw[0].setdefault("figsize", fig_kw["figsize"])
cfig_kw.pop("figsize")
Expand Down Expand Up @@ -2503,9 +2509,9 @@ def hatchmap(
im = v.where(mask is not True).plot.contourf(**plot_kw[k])
artists, labels = im.legend_elements(str_format="{:2.1f}".format)

if ax:
if ax and legend_kw:
ax.legend(artists, labels, **legend_kw)
else:
elif legend_kw:
im.figlegend = im.fig.legend(**legend_kw)

elif len(plot_data) > 1 and "levels" in plot_kw[k]:
Expand All @@ -2519,6 +2525,13 @@ def hatchmap(
if "hatches" not in plot_kw[k].keys():
plot_kw[k]["hatches"] = dfh[n]
n += 1
elif isinstance(
plot_kw[k]["hatches"], str
): # make sure the hatches are in a list
warnings.warn(
"Hatches argument must be of type 'list'. Wrapping string argument as list."
)
plot_kw[k]["hatches"] = [plot_kw[k]["hatches"]]

plot_kw[k].setdefault("transform", transform)
if ax:
Expand Down Expand Up @@ -2552,31 +2565,31 @@ def hatchmap(
geometries_kw,
frame,
)
if extend:
fax.set_extent(extend)
if extent:
fax.set_extent(extent)

pat_leg.append(
matplotlib.patches.Patch(
hatch=plot_kw[k]["hatches"], fill=False, label=k
hatch=plot_kw[k]["hatches"][0], fill=False, label=k
)
)

if pat_leg:
if pat_leg and legend_kw:
legend_kw = {
"loc": "lower right",
"handleheight": 2,
"handlelength": 4,
} | legend_kw

if ax:
if ax and legend_kw:
ax.legend(handles=pat_leg, **legend_kw)
else:
elif legend_kw:
im.figlegend = im.fig.legend(handles=pat_leg, **legend_kw)

# add features
if ax:
if extend:
ax.set_extend(extend)
if extent:
ax.set_extent(extent)
if dattrs:
use_attrs.setdefault("title", "description")

Expand Down
Loading