Skip to content

Commit

Permalink
Fix unintended un-masking for detail tables
Browse files Browse the repository at this point in the history
  • Loading branch information
mkstratos committed Feb 27, 2023
1 parent 2e8b340 commit efb4e4e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
5 changes: 3 additions & 2 deletions evv4esm/ensembles/e3sm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,12 @@ def load_mpas_climatology_ensemble(files, field_name, mask_value=None):
var_desc = "{long_name}{units}".format(**get_variable_meta(dset, field_name))

dims = _field.shape
ens_out = np.zeros([*dims, len(files)])
ens_out = np.ma.zeros([*dims, len(files)])
ens_out[..., 0] = _field
for idx, file_name in enumerate(files[1:]):
with Dataset(file_name, "r") as dset:
ens_out[..., idx + 1] = dset.variables[field_name][:].squeeze()
_field = dset.variables[field_name][:].squeeze()
ens_out[..., idx + 1] = _field

if mask_value:
ens_out = np.ma.masked_less(ens_out, mask_value)
Expand Down
12 changes: 9 additions & 3 deletions evv4esm/extensions/kso.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,9 @@ def main(args):
# performing the test (across ensemble members) to be the last dimension
# (e.g. [nCells, nLevels, nEns]) this is why load_mpas_climatology_ensemble
# returns data in this way
ks_test = np.vectorize(stats.mstats.ks_2samp, signature="(n),(n)->(),()")
ks_test = np.vectorize(
stats.mstats.ks_2samp, signature="(n),(n)->(),()", excluded=["method"]
)

images = {"accept": [], "reject": [], "-": []}
details = LIVVDict()
Expand All @@ -387,8 +389,13 @@ def main(args):

annuals_1 = var_1["data"]
annuals_2 = var_2["data"]
if isinstance(annuals_1, np.ma.MaskedArray) and isinstance(
annuals_2, np.ma.MaskedArray
):
_, p_val = ks_test(annuals_1.filled(), annuals_2.filled(), method="asymp")
else:
_, p_val = ks_test(annuals_1, annuals_2, method="asymp")

_, p_val = ks_test(annuals_1, annuals_2)
null_reject_pre_correct = np.sum(np.where(p_val <= args.alpha, 1, 0))
_, p_val = smm.fdrcorrection(
p_val.flatten(), alpha=args.alpha, method="indep", is_sorted=False
Expand Down Expand Up @@ -417,7 +424,6 @@ def main(args):
mask_value = -0.9999e33
annuals_1 = np.ma.masked_less(annuals_1, mask_value)
annuals_2 = np.ma.masked_less(annuals_2, mask_value)

details[var]["mean (test case, ref. case)"] = (
annuals_1.mean(),
annuals_2.mean(),
Expand Down

0 comments on commit efb4e4e

Please sign in to comment.