Skip to content

Commit

Permalink
Fix xr parser issues (#612)
Browse files Browse the repository at this point in the history
* fix failures in the attribute comparison functions caused by scalar coords that hold multiple values (e.g., plev19)

* check that tv_name is in the exclusion varlist before trying to remove it in the preprocessor.process case_list loop
* add support for multi-value scalar_coords to ConvertUnitsFunction
  • Loading branch information
wrongkindofdoctor authored Jul 11, 2024
1 parent 635c246 commit 82b7cfb
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 32 deletions.
13 changes: 11 additions & 2 deletions src/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,16 @@ def execute(self, var, ds, **kwargs):
ds, c.name, src_unit=None, dest_unit=dest_c.units,
log=var.log
)
c.value = None
if len(ds[c.name]) > 1:
for v in ds[c.name].values:
if int(v) / dest_c.value == 100: # v = dest_c in Pa
c.value = dest_c.value
elif int(v) == dest_c.value:
c.value = v
else:
c.value = ds[c.name].item()
c.units = dest_c.units
c.value = ds[c.name].item()

var.log.info("Converted units on %s.", var.full_name)
return ds
Expand Down Expand Up @@ -1298,7 +1306,8 @@ def process(self,
tv_name = v.translation.name
var_xr_dataset = self.parse_ds(v, case_xr_dataset)
varlist_ex = [v_l.translation.name for v_l in case_list[case_name].varlist.iter_vars()]
varlist_ex.remove(tv_name)
if tv_name in varlist_ex:
varlist_ex.remove(tv_name)
for v_d in var_xr_dataset.variables:
if v_d not in varlist_ex:
cat_subset[case_name].update({v_d: var_xr_dataset[v_d]})
Expand Down
71 changes: 41 additions & 30 deletions src/xr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ def compare_attr(self, our_attr_tuple, ds_attr_tuple, comparison_func=None,
fill_ours = (fill_ours and ds_attr is not ATTR_NOT_FOUND)
if overwrite_ours is not None:
overwrite_ours = (overwrite_ours and ds_attr is not ATTR_NOT_FOUND)

self.overwrite_ds = (ds_attr is ATTR_NOT_FOUND)
if self.overwrite_ds:
# user set CLI option to force overwrite of ds from our_var
fill_ds = True
Expand Down Expand Up @@ -810,7 +812,14 @@ def compare_attr(self, our_attr_tuple, ds_attr_tuple, comparison_func=None,
if fill_ours:
# update our attr with value from ds, but also raise error
setattr(our_var, our_attr_name, ds_attr)
comparison_func = self.approximate_attribute_value(our_attr, ds_attr)
if isinstance(our_attr, tuple) and isinstance(ds_attr, tuple):
our_attr_str = [i for i in our_attr if isinstance(i, str)][0]
ds_attr_str = [i for i in ds_attr if isinstance(i, str)][0]
comparison_func = True
if not our_attr_str == ds_attr_str:
comparison_func = self.approximate_attribute_value(our_attr_str, ds_attr_str)
else:
comparison_func = self.approximate_attribute_value(our_attr, ds_attr)
if not comparison_func:
raise util.MetadataEvent((f"Unexpected {our_attr_name} for variable "
f"'{our_var.name}': '{ds_attr}' (expected '{our_attr}')."))
Expand Down Expand Up @@ -861,12 +870,13 @@ def reconcile_attr(self, our_var, ds_var, our_attr_name, ds_attr_name=None,
ds_attr_name = our_attr_name
our_attr = getattr(our_var, our_attr_name)
ds_attr = ds_var.attrs.get(ds_attr_name, ATTR_NOT_FOUND)

self.compare_attr(
(our_var, our_attr_name, our_attr), (ds_var, ds_attr_name, ds_attr),
**kwargs
)

def reconcile_names(self, our_var, ds, ds_var_name, overwrite_ours=None):
def reconcile_names(self, our_var, ds, ds_var_name: str, overwrite_ours=None):
"""Reconcile the name and standard_name attributes between the
'ground truth' of the dataset we downloaded (*ds_var_name*) and our
expectations based on the model's convention (*our_var*).
Expand All @@ -880,11 +890,17 @@ def reconcile_names(self, our_var, ds, ds_var_name, overwrite_ours=None):
overwrite_ours (bool, default False): If True, always update the name
of *our_var* to what's found in *ds*.
"""
# Determine if a variable with the expected name is present at all
if ds_var_name not in ds:
if self.guess_names:

if ds_var_name not in ds.variables:
# try searching for 4-D field instead of variable name for a specific level
# (e.g., U instead of U500)
tv_name = ds_var_name
if len(our_var.scalar_coords) > 0:
ds_var_name = ''.join(filter(lambda x: not x.isdigit(), tv_name))
overwrite_ours = True
else:
# attempt to match on standard_name attribute if present in data
ds_names = [v.name for v in ds.variables \
ds_names = [v.name for v in ds.variables
if v.attrs.get('standard_name', "") == our_var.standard_name]
if len(ds_names) == 1:
# success, narrowed down to one guess
Expand All @@ -898,10 +914,6 @@ def reconcile_names(self, our_var, ds, ds_var_name, overwrite_ours=None):
# failure
raise util.MetadataError(f"Variable name '{ds_var_name}' not "
f"found in dataset: ({list(ds.variables)}).")
else:
# not guessing; error out
raise util.MetadataError(f"Variable name '{ds_var_name}' not found "
f"in dataset: ({list(ds.variables)}).")

# in all non-error cases: now that variable has been identified in ds,
# straightforward to compare attrs
Expand Down Expand Up @@ -990,26 +1002,25 @@ def _compare_value_and_units(our_var, ds_var, comparison_func=None):
# "attribute" to compare is tuple of (numerical value, units string),
# which is converted to unit-ful object by src.units.to_cfunits()
our_attr = (our_var.value, our_var.units)
ds_attr = (float(ds_var), ds_var.attrs.get('units', ATTR_NOT_FOUND))
try:
self.compare_attr(
(our_var, attr_name, our_attr), (ds_var, attr_name, ds_attr),
comparison_func=comparison_func,
fill_ours=True, fill_ds=False
)
finally:
# cleanup placeholder attr if our_var was altered
if hasattr(our_var, attr_name):
our_var.value, new_units = getattr(our_var, attr_name)
our_var.units = units.to_cfunits(new_units)
self.log.debug("Updated (value, units) of '%s' to (%s, %s).",
our_var.name, our_var.value, our_var.units)
delattr(our_var, attr_name)

assert (hasattr(our_var, 'is_scalar') and our_var.is_scalar), \
for i in ds_var.values:
ds_attr = (float(i), ds_var.attrs.get('units', ATTR_NOT_FOUND))
try:
self.compare_attr(
(our_var, attr_name, our_attr), (ds_var, attr_name, ds_attr),
comparison_func=comparison_func,
fill_ours=True, fill_ds=False
)
finally:
# cleanup placeholder attr if our_var was altered
if hasattr(our_var, attr_name):
our_var.value, new_units = getattr(our_var, attr_name)
our_var.units = units.to_cfunits(new_units)
self.log.debug("Updated (value, units) of '%s' to (%s, %s).",
our_var.name, our_var.value, our_var.units)
delattr(our_var, attr_name)

assert (hasattr(our_var, 'is_scalar') and our_var.is_scalar),\
self.log.error('is_scalar att missing and/or is_scalar is false for ', our_var)
assert ds_var.size == 1, \
self.log.error('size neq 1 for ', our_var)
# Check equivalence of units: if units inequivalent, raises MetadataEvent
try:
_compare_value_and_units(
Expand Down Expand Up @@ -1169,7 +1180,7 @@ def reconcile_variable(self, var, ds):
coordinates in *translated_var* (our expectation, based on the DataSource's
naming convention) with attributes actually present in the Dataset *ds*.
"""
tv = var.translation # abbreviate
tv = var.translation
# check name, std_name, units on variable itself
self.reconcile_names(tv, ds, tv.name, overwrite_ours=None)
self.reconcile_units(tv, ds[tv.name])
Expand Down

0 comments on commit 82b7cfb

Please sign in to comment.