diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 0fd407f..cd71791 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -1,3 +1,5 @@ +v0.1.1 + - Improved `plot.pairplot` adding `xr.Dataset` input support. v0.1.0 - Added `io` module for jsons. v0.0.5 diff --git a/CITATION.cff b/CITATION.cff index 2c0174a..9785c8f 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,7 +5,7 @@ authors: given-names: Simon Donald Alistair orcid: https://orcid.org/0000-0001-7911-1659 title: Sithom's Scientific Python Utilities Package -version: v0.1.0 +version: v0.1.1 doi: 10.5281/zenodo.7020109 url: "https://github.com/sdat2/sithom" -date-released: 2024-09-25 +date-released: 2024-10-08 diff --git a/sithom/_version.py b/sithom/_version.py index ac7dabc..05b3d67 100644 --- a/sithom/_version.py +++ b/sithom/_version.py @@ -1,6 +1,6 @@ """Module to store the __version__ string etc.""" -__version__ = "0.1.0" +__version__ = "0.1.1" __project__ = "Simon's utility scripts" __copyright__ = "2024, Simon Thomas" __author__ = "Simon Thomas" diff --git a/sithom/plot.py b/sithom/plot.py index ebf8d59..6354187 100644 --- a/sithom/plot.py +++ b/sithom/plot.py @@ -41,7 +41,7 @@ """ -from typing import Sequence, Tuple, Optional, Literal, List +from typing import Sequence, Tuple, Optional, Literal, List, Union import itertools from shutil import which import numpy as np @@ -466,17 +466,53 @@ def lim( return (float(vmin), float(vmax)) -def pairplot(df: pd.DataFrame) -> None: +def _pairplot_ds(ds: xr.Dataset, vars: Optional[List[str]] = None, label: bool = False) -> None: + """Pairplot an xarray Dataset via a DataFrame with "long_name [units]" column names. + + Args: + ds (xr.Dataset): Dataset to plot. + vars (Optional[List[str]], optional): Variables to plot. Defaults to None, which plots every data variable. + label (bool, optional): Whether to label the subplots. Defaults to False.
+ """ + vars = vars if vars else list(ds.data_vars) + ds = ds[vars] + rn_dict = {} + + for var in ds: + if "long_name" in ds[var].attrs: + rn_dict[var] = ds[var].attrs["long_name"] + else: + rn_dict[var] = var + if "units" in ds[var].attrs: + rn_dict[var] += " [" + ds[var].attrs["units"] + "]" + + df = ds.rename(rn_dict).to_dataframe()[list(rn_dict.values())] + pairplot(df, label=label) + + +def pairplot(inp: Union[xr.Dataset, pd.DataFrame], + vars: Optional[List[str]] = None, + label: bool = False) -> None: """ Improved seaborn pairplot from: https://stackoverflow.com/a/50835066 - TODO: Add option for subplot labels (a), (b), (c) etc. + TODO: Improve option for subplot labels (a), (b), (c) etc. Args: - df (pd.DataFrame): A data frame. + inp (Union[xr.Dataset, pd.DataFrame]): A dataset or dataframe to plot. + vars (Optional[List[str]], optional): Variables (data variables or columns) to plot. Defaults to None, which plots all of them. + label (bool, optional): Whether to label the subplots. Defaults to False. """ + if isinstance(inp, xr.Dataset): + return _pairplot_ds(inp, vars=vars, label=label) + elif isinstance(inp, pd.DataFrame): + df = inp[vars] if vars else inp + else: + raise ValueError("Input must be a pandas DataFrame or xarray Dataset.") + + ax_list = [] def corrfunc(x, y, ax=None, **kws) -> None: """Plot the correlation coefficient in the @@ -487,11 +523,21 @@ def corrfunc(x, y, ax=None, **kws) -> None: corr = ma.corrcoef(ma.masked_invalid(x), ma.masked_invalid(y)) corr_coeff = corr[0, 1] ax = ax or plt.gca() - ax.annotate(f"ρ = {corr_coeff:.2f}", xy=(0.05, 1.0), xycoords=ax.transAxes) + ax.annotate(f"ρ = {corr_coeff:.2f}", xy=(0.35, 1.0), xycoords=ax.transAxes) g = sns.pairplot(df, corner=True) g.map_lower(corrfunc) + def get_ax(x, y, ax=None, **kws) -> None: + nonlocal ax_list + ax = ax or plt.gca() + ax_list.append(ax) + + if label: + g.map_lower(get_ax) + # g.map_diag(get_ax) + label_subplots(ax_list, start_from=0, fontsize=10, x_pos=0.06, y_pos=1.03)#, override="outside") + def feature_grid( ds: xr.Dataset,