Skip to content

Commit

Permalink
Merge pull request #11 from LIVVkit/mkstratos/mvko
Browse files Browse the repository at this point in the history
Add ocean K-S test
  • Loading branch information
mkstratos authored Mar 8, 2023
2 parents 9b721cc + 92ab3e9 commit 15693ca
Show file tree
Hide file tree
Showing 6 changed files with 579 additions and 30 deletions.
2 changes: 1 addition & 1 deletion evv4esm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


__version_info__ = (0, 3, 2)
__version_info__ = (0, 4, 0)
__version__ = '.'.join(str(vi) for vi in __version_info__)

PASS_COLOR = '#389933'
Expand Down
89 changes: 64 additions & 25 deletions evv4esm/ensembles/e3sm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,45 +37,71 @@
import glob

from collections import OrderedDict
from functools import partial

import numpy as np
import pandas as pd
from netCDF4 import Dataset


def component_file_instance(component, case_file):
search_regex = r'{c}_[0-9]+'.format(c=component)
search_regex = r"{c}_[0-9]+".format(c=component)
result = re.search(search_regex, case_file).group(0)
return int(result.replace('{}_'.format(component), ''))
return int(result.replace("{}_".format(component), ""))


def file_date_str(case_file, style='short'):
if style == 'full':
search_regex = r'h0\.[0-9]+-[0-9]+-[0-9]+-[0-9]+.nc'
elif style == 'short':
search_regex = r'h0\.[0-9]+-[0-9]+.nc'
def file_date_str(case_file, style="short", hist_name="h0"):
if style == "full":
search_regex = r"{}\.[0-9]+-[0-9]+-[0-9]+-[0-9]+.nc".format(hist_name)
elif style == "med":
search_regex = r"{}\.[0-9]+-[0-9]+-[0-9]+.nc".format(hist_name)
elif style == "short":
search_regex = r"{}\.[0-9]+-[0-9]+.nc".format(hist_name)
else:
# FIXME: log warning here
search_regex = r'h0\.[0-9]+-[0-9]+.nc'
search_regex = r"{}\.[0-9]+-[0-9]+.nc".format(hist_name)

result = re.search(search_regex, case_file).group(0)
return result.replace('h0.', '').replace('.nc', '')
return result.replace("{}.".format(hist_name), "").replace(".nc", "")


def component_monthly_files(dir_, component, ninst, hist_name="h0", nmonth_max=12, date_style="short"):
if date_style == "full":
date_search = "????-??-??-??"
elif date_style == "med":
date_search = "????-??-??"
else:
date_search = "????-??"

def component_monthly_files(dir_, component, ninst):
base = '{d}/*{c}_????.h0.????-??.nc'.format(d=dir_, c=component)
def component_monthly_files(dir_, component, ninst, hist_name="hist", nmonth_max=24, date_style="short"):
base = "{d}/*{c}_????.{n}.????-??-??.nc".format(d=dir_, c=component, n=hist_name)
search = os.path.normpath(base)
result = sorted(glob.glob(search))

instance_files = OrderedDict()
_file_date_str = partial(file_date_str, style=date_style, hist_name=hist_name)
for ii in range(1, ninst + 1):
instance_files[ii] = sorted(filter(lambda x: component_file_instance(component, x) == ii, result),
key=file_date_str)
if len(instance_files[ii]) > 12:
instance_files[ii] = instance_files[ii][-12:]
instance_files[ii] = sorted(
filter(lambda x: component_file_instance(component, x) == ii, result),
key=_file_date_str,
)
if len(instance_files[ii]) > nmonth_max:
instance_files[ii] = instance_files[ii][-nmonth_max:]

return instance_files


def get_variable_meta(dataset, var_name):
try:
_name = f": {dataset.variables[var_name].getncattr('long_name')}"
except AttributeError:
_name = ""
try:
_units = f" [{dataset.variables[var_name].getncattr('units')}]"
except AttributeError:
_units = ""
return {"long_name": _name, "units": _units}


def gather_monthly_averages(ensemble_files, variable_set=None):
monthly_avgs = []
for case, inst_dict in six.iteritems(ensemble_files):
Expand All @@ -101,16 +127,29 @@ def gather_monthly_averages(ensemble_files, variable_set=None):
continue
else:
m = np.mean(data.variables[var][0, ...])
try:
_name = f": {data.variables[var].getncattr('long_name')}"
except AttributeError:
_name = ""
try:
_units = f" [{data.variables[var].getncattr('units')}]"
except AttributeError:
_units = ""
desc = f"{_name}{_units}"

desc = "{long_name}{units}".format(**get_variable_meta(data, var))
monthly_avgs.append((case, var, '{:04}'.format(inst), date_str, m, desc))

monthly_avgs = pd.DataFrame(monthly_avgs, columns=('case', 'variable', 'instance', 'date', 'monthly_mean', 'desc'))
return monthly_avgs


def load_mpas_climatology_ensemble(files, field_name, mask_value=None):
# Get the first file to set up ensemble array output
with Dataset(files[0], "r") as dset:
_field = dset.variables[field_name][:].squeeze()
var_desc = "{long_name}{units}".format(**get_variable_meta(dset, field_name))

dims = _field.shape
ens_out = np.ma.zeros([*dims, len(files)])
ens_out[..., 0] = _field
for idx, file_name in enumerate(files[1:]):
with Dataset(file_name, "r") as dset:
_field = dset.variables[field_name][:].squeeze()
ens_out[..., idx + 1] = _field

if mask_value:
ens_out = np.ma.masked_less(ens_out, mask_value)

return {"data": ens_out, "desc": var_desc}
22 changes: 18 additions & 4 deletions evv4esm/ensembles/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,31 @@ def prob_plot(test, ref, n_q, img_file, test_name='Test', ref_name='Ref.',
# NOTE: Produce unity-based normalization of data for the Q-Q plots because
# matplotlib can't handle small absolute values or data ranges. See
# https://github.com/matplotlib/matplotlib/issues/6015
if not np.allclose(min_, max_, atol=np.finfo(max_).eps):
if not np.allclose(min_, max_, rtol=np.finfo(max_).eps):
norm1 = (ref - min_) / (max_ - min_)
norm2 = (test - min_) / (max_ - min_)

ax1.scatter(np.percentile(norm1, q), np.percentile(norm2, q),
color=pf_color_picker.get(pf, '#1F77B4'), zorder=2)
ax3.hist(norm1, bins=n_q, color=pf_color_picker.get(pf, '#1F77B4'))
ax4.hist(norm2, bins=n_q, color=pf_color_picker.get(pf, '#1F77B4'))
ax3.hist(norm1, bins=n_q, color=pf_color_picker.get(pf, '#1F77B4'), edgecolor="k")
ax4.hist(norm2, bins=n_q, color=pf_color_picker.get(pf, '#1F77B4'), edgecolor="k")

# Check if these distributions are wildly different. If they are, use different
# colours for the bottom axis? Otherwise set the scales to be the same [0, 1]
if abs(norm1.mean() - norm2.mean()) >= 0.5:
ax3.tick_params(axis="x", colors="C0")
ax3.spines["bottom"].set_color("C0")

ax4.tick_params(axis="x", colors="C1")
ax4.spines["bottom"].set_color("C1")
else:
ax3.set_xlim(tuple(norm_rng))
ax4.set_xlim(tuple(norm_rng))


# bin both series into equal bins and get cumulative counts for each bin
bnds = np.linspace(min_, max_, n_q)
if not np.allclose(bnds, bnds[0], atol=np.finfo(bnds[0]).eps):
if not np.allclose(bnds, bnds[0], rtol=np.finfo(bnds[0]).eps):
ppxb = pd.cut(ref, bnds)
ppyb = pd.cut(test, bnds)

Expand All @@ -124,6 +137,7 @@ def prob_plot(test, ref, n_q, img_file, test_name='Test', ref_name='Ref.',
ax2.scatter(ppyh.values, ppxh.values,
color=pf_color_picker.get(pf, '#1F77B4'), zorder=2)


plt.tight_layout()
plt.savefig(img_file, bbox_inches='tight')

Expand Down
Loading

0 comments on commit 15693ca

Please sign in to comment.