Skip to content

Commit

Permalink
improved pairplot
Browse files Browse the repository at this point in the history
  • Loading branch information
sdat2 committed Oct 8, 2024
1 parent 216434b commit e7b96a2
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
v0.1.1
- Improved `plot.pairplot` adding `xr.Dataset` input support.
v0.1.0
- Added `io` module for jsons.
v0.0.5
Expand Down
4 changes: 2 additions & 2 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ authors:
given-names: Simon Donald Alistair
orcid: https://orcid.org/0000-0001-7911-1659
title: Sithom's Scientific Python Utilities Package
version: v0.1.0
version: v0.1.1
doi: 10.5281/zenodo.7020109
url: "https://github.com/sdat2/sithom"
date-released: 2024-09-25
date-released: 2024-10-08
2 changes: 1 addition & 1 deletion sithom/_version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Module to store the __version__ string etc."""

__version__ = "0.1.0"
__version__ = "0.1.1"
__project__ = "Simon's utility scripts"
__copyright__ = "2024, Simon Thomas"
__author__ = "Simon Thomas"
56 changes: 51 additions & 5 deletions sithom/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"""

from typing import Sequence, Tuple, Optional, Literal, List
from typing import Sequence, Tuple, Optional, Literal, List, Union
import itertools
from shutil import which
import numpy as np
Expand Down Expand Up @@ -466,17 +466,53 @@ def lim(
return (float(vmin), float(vmax))


def pairplot(df: pd.DataFrame) -> None:
def _pairplot_ds(
    ds: xr.Dataset,
    vars: Optional[List[str]] = None,
    label: bool = False,
) -> None:
    """Pairplot the data variables of an xarray Dataset.

    Each selected variable is renamed to a display label built from its
    ``long_name`` attribute (falling back to the variable name itself), with
    `` [units]`` appended when a ``units`` attribute is present. The renamed
    dataset is then converted to a dataframe and handed to ``pairplot``.

    Args:
        ds (xr.Dataset): Dataset to plot.
        vars (Optional[List[str]], optional): Variables to plot. Any falsy
            value (``None``, ``False`` or an empty list) selects every data
            variable in ``ds``. Defaults to None.
        label (bool, optional): Whether to label the subplots; forwarded
            unchanged to ``pairplot``. Defaults to False.
    """
    # NOTE: the parameter name ``vars`` shadows the builtin, but it is kept
    # because ``pairplot`` passes it by keyword.
    vars = vars if vars else list(ds.data_vars)
    ds = ds[vars]

    # Map each variable name to its display label.
    rn_dict = {}
    for var in ds:
        attrs = ds[var].attrs
        name = attrs.get("long_name", var)
        if "units" in attrs:
            # f-string rather than ``+`` so that non-string attribute values
            # are formatted instead of raising a TypeError.
            name = f"{name} [{attrs['units']}]"
        rn_dict[var] = name

    # Select only the renamed columns, in the order the variables were
    # iterated, so that no extra columns are plotted.
    df = ds.rename(rn_dict).to_dataframe()[list(rn_dict.values())]
    pairplot(df, label=label)


def pairplot(inp: Union[xr.Dataset, pd.DataFrame],
vars: Optional[List[str]] = None,
label: bool = False) -> None:
"""
Improved seaborn pairplot from:
https://stackoverflow.com/a/50835066
TODO: Add option for subplot labels (a), (b), (c) etc.
TODO: Improve option for subplot labels (a), (b), (c) etc.
Args:
df (pd.DataFrame): A data frame.
inp (Union[xr.Dataset, pd.DataFrame]): A dataset or dataframe to plot.
vars (Optional[List[str]], optional): Variables to plot. Defaults to None.
label (bool, optional): Whether to label the subplots. Defaults to False.
"""
if isinstance(inp, xr.Dataset):
return _pairplot_ds(inp, vars=vars, label=label)
elif isinstance(inp, pd.DataFrame):
df = inp
else:
raise ValueError("Input must be a pandas DataFrame or xarray Dataset.")

ax_list = []

def corrfunc(x, y, ax=None, **kws) -> None:
"""Plot the correlation coefficient in the
Expand All @@ -487,11 +523,21 @@ def corrfunc(x, y, ax=None, **kws) -> None:
corr = ma.corrcoef(ma.masked_invalid(x), ma.masked_invalid(y))
corr_coeff = corr[0, 1]
ax = ax or plt.gca()
ax.annotate(f"ρ = {corr_coeff:.2f}", xy=(0.05, 1.0), xycoords=ax.transAxes)
ax.annotate(f"ρ = {corr_coeff:.2f}", xy=(0.35, 1.0), xycoords=ax.transAxes)

g = sns.pairplot(df, corner=True)
g.map_lower(corrfunc)

def get_ax(x, y, ax=None, **kws) -> None:
nonlocal ax_list
ax = ax or plt.gca()
ax_list.append(ax)

if label:
g.map_lower(get_ax)
# g.map_diag(get_ax)
label_subplots(ax_list, start_from=0, fontsize=10, x_pos=0.06, y_pos=1.03)#, override="outside")


def feature_grid(
ds: xr.Dataset,
Expand Down

0 comments on commit e7b96a2

Please sign in to comment.