Skip to content

Commit

Permalink
Add alternate support (#604)
Browse files Browse the repository at this point in the history
* add function new_dict_wo_key to util/basic.py

* add boolean attribute is_alternate to VarlistEntry class
remove iter_shallow_alternates function from varlist_util.py
refactor alternates attribute so that it is a list of VarlistEntry classes instead of a nested list
add call to new_dict_wo_key to remove alternate VarlistEntry keys from Varlist dict after they have been appended as attributes to the variable VarlistEntry that they can be substituted for

* refine alternate check in preprocessor query_catalog to use translation info attached to alternates if available
need to fix issues with errors thrown by metadata checks for cases where 4d field is present, but doesn't match query for specific level

* add varlist_util to commit

* add logic to sub long_name for standard_name if necessary to xr_parser check_ds_attrs
add check to sub 4D field name for variable level name in ds search to check_ds_attrs
  • Loading branch information
wrongkindofdoctor authored Jul 9, 2024
1 parent 315ccba commit f34b6f4
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 25 deletions.
11 changes: 9 additions & 2 deletions src/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def edit_request(self, v: varlist_util.VarlistEntry, **kwargs):

# add original 4D var defined in new_tv as an alternate TranslatedVarlistEntry
# to query if no entries on specified levels are found in the data catalog

v.alternates.append(new_tv)

return v
Expand Down Expand Up @@ -922,7 +923,13 @@ def query_catalog(self,
if any(var.alternates):
try_new_query = True
for a in var.alternates:
case_d.query.update({'variable_id': a.name})
if hasattr(a, 'translation'):
if a.translation is not None:
case_d.query.update({'variable_id': a.translation.name})
case_d.query.update({'standard_name': a.translation.standard_name})
else:
case_d.query.update({'variable_id': a.name})
case_d.query.update({'standard_name': a.standard_name})
if any(var.translation.scalar_coords):
found_z_entry = False
# check for vertical coordinate to determine if level extraction is needed
Expand Down Expand Up @@ -1275,7 +1282,7 @@ def process(self,
cat_subset = self.query_catalog(case_list, config.DATA_CATALOG)
for case_name, case_xr_dataset in cat_subset.items():
for v in case_list[case_name].varlist.iter_vars():
tv_name = v.translation.name #abbreviate
tv_name = v.translation.name
var_xr_dataset = self.parse_ds(v, case_xr_dataset)
cat_subset[case_name]['time'] = var_xr_dataset['time']
cat_subset[case_name].update({tv_name: var_xr_dataset[tv_name]})
Expand Down
2 changes: 1 addition & 1 deletion src/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ConsistentDict, WormDefaultDict, NameSpace, MDTFEnum,
sentinel_object_factory, MDTF_ID, deactivate, ObjectStatus,
is_iterable, to_iter, from_iter, remove_prefix, RegexDict,
remove_suffix, filter_kwargs, splice_into_list
remove_suffix, filter_kwargs, splice_into_list, new_dict_wo_key
)

from .logs import (
Expand Down
7 changes: 7 additions & 0 deletions src/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
import collections
import collections.abc
import copy
import enum
import itertools
import string
Expand Down Expand Up @@ -709,3 +710,9 @@ def get_all_matching_values(self, queries: list):
return (match for query in queries for match in self.get_matching_value(query))


# source: https://stackoverflow.com/questions/5844672/delete-an-element-from-a-dictionary
def new_dict_wo_key(dictionary: dict, key: str) -> dict:
    """Returns a **shallow** copy of the input dictionary without a key.

    The input dictionary is never mutated; if ``key`` is absent, the copy is
    simply returned unchanged.
    """
    # copy.copy (rather than dict(...)) preserves the concrete dict subclass
    trimmed = copy.copy(dictionary)
    try:
        del trimmed[key]
    except KeyError:
        pass  # key not present: nothing to remove
    return trimmed
69 changes: 48 additions & 21 deletions src/varlist_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class VarlistEntry(VarlistEntryBase, util.MDTFObjectBase, data_model.DMVariable,
status: util.ObjectStatus = dc.field(default=util.ObjectStatus.NOTSET, compare=False)
name: str = util.MANDATORY
_parent: typing.Any = dc.field(default=util.MANDATORY, compare=False)
is_alternate: bool = False

def __post_init__(self, coords=None):
# set up log (VarlistEntryLoggerMixin)
Expand Down Expand Up @@ -215,11 +216,12 @@ def __post_init__(self, coords=None):
self.path_variable = self.name.upper() + _file_env_var_suffix
# self.alternates is either [] or a list of nonempty lists of VEs
if hasattr(self, 'alternates'):
if isinstance(self.alternates, list):
if any(self.alternates):
if not isinstance(self.alternates[0], list):
self.alternates = [self.alternates]
self.alternates = [vs for vs in self.alternates if vs]
if not isinstance(self.alternates, list):
self.alternates = [self.alternates]
else:
self.alternates = []
if self.requirement == VarlistEntryRequirement.ALTERNATE:
self.is_alternate = True
if hasattr(self, 'scalar_coords'):
self.scalar_coords = self.scalar_coords

Expand Down Expand Up @@ -332,10 +334,15 @@ def set_env_vars(self):
})

for ax, dim in self.dim_axes.items():
trans_dim = self.translation.axes[ax]
self.env_vars[dim.name + _coord_env_var_suffix] = trans_dim.name
if trans_dim.has_bounds:
self.env_vars[dim.name + _coord_bounds_env_var_suffix] = trans_dim.bounds
if self.translation is not None:
trans_dim = self.translation.axes[ax]
self.env_vars[dim.name + _coord_env_var_suffix] = trans_dim.name
if trans_dim.has_bounds:
self.env_vars[dim.name + _coord_bounds_env_var_suffix] = trans_dim.bounds
else:
self.env_vars[dim.name + _coord_env_var_suffix] = dim.name
if dim.has_bounds:
self.env_vars[dim.name + _coord_bounds_env_var_suffix] = dim.bounds

def iter_alternates(self):
"""Breadth-first traversal of "sets" of alternate VarlistEntries,
Expand Down Expand Up @@ -567,8 +574,30 @@ def setup_var(self,
# store but don't deactivate, because preprocessor.edit_request()
# may supply alternate variables
v.log.store_exception(chained_exc)
# set the VarlistEntry env_vars (required for backwards compatibility with first-gen PODs)

# set the VarlistEntry env_vars (required for backwards compatibility with first-gen PODs
v.set_env_vars()
# Translate alternate vars if necessary
for alt_v in v.alternates:
if alt_v.T is not None:
alt_v.change_coord(
'T',
new_class={
'self': VarlistTimeCoordinate,
'range': util.DateRange,
'frequency': util.DateFrequency
},
range=date_range,
calendar=util.NOTSET,
units=util.NOTSET
)
alt_v.dest_path = self.variable_dest_path(model_paths, case_name, alt_v)
trans_alt = translate.translate(alt_v, from_convention)
alt_v.translation = trans_alt
if trans_alt is None:
alt_v.info(f'Note: alternate variable {alt_v.full_name} not translated from {from_convention} to'
f' {to_convention}')
alt_v.set_env_vars()

def variable_dest_path(self,
model_paths: util.ModelDataPathManager,
Expand Down Expand Up @@ -615,13 +644,6 @@ def _pod_dimension_from_struct(name, dd, v_settings):
except Exception:
raise ValueError(f"Couldn't parse dimension entry for {name}: {dd}")

def _iter_shallow_alternates(var):
"""Iterator over all VarlistEntries referenced as alternates. Doesn't
traverse alternates of alternates, etc.
"""
for alt_vs in var.alternates:
yield from alt_vs

vlist_settings = util.coerce_to_dataclass(
parent.pod_data, VarlistSettings)
globals_d = vlist_settings.global_settings
Expand All @@ -634,12 +656,17 @@ def _iter_shallow_alternates(var):
}
for v in vlist_vars.values():
# validate & replace names of alt vars with references to VE objects
for altv_name in _iter_shallow_alternates(v):
for altv_name in v.alternates:
if altv_name not in vlist_vars:
raise ValueError((f"Unknown variable name {altv_name} listed "
f"in alternates for varlist entry {v.name}."))
linked_alts = []
for alts in v.alternates:
linked_alts.append([vlist_vars[v_name] for v_name in alts])
linked_alts = [vlist_vars[v_name] for v_name in v.alternates]
v.alternates = linked_alts
alt_vars = [k for k, v in vlist_vars.items() if v.is_alternate]
for a in alt_vars:
vlist_vars = util.new_dict_wo_key(vlist_vars, a)

# remove alternates from VarlistEntries since they are now attributes of variable
# VarlistEntry objects that they can be substituted for

return cls(contents=list(vlist_vars.values()))
18 changes: 17 additions & 1 deletion src/xr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,11 @@ def check_metadata(self, ds_var, *attr_names):
if attr not in ds_var.attrs:
if attr in ds_var.encoding:
ds_var.attrs[attr] = ds_var.encoding[attr]
# TODO: maybe move the following block to reconcile_attrs and refactor
elif attr == 'standard_name' and 'long_name' in ds_var.attrs:
ds_var.attrs[attr] = ds_var.attrs['long_name']
elif attr == 'standard_name' and 'long_name' in ds_var.encoding:
ds_var.attrs[attr] = ds_var.encoding['long_name']
else:
ds_var.attrs[attr] = ATTR_NOT_FOUND
if ds_var.attrs[attr] is ATTR_NOT_FOUND:
Expand All @@ -1240,7 +1245,18 @@ def check_ds_attrs(self, var, ds):
# Only check attributes on the dependent variable var_name and its
# coordinates.
tv_name = var.translation.name
names_to_check = [tv_name] + list(ds[tv_name].dims)
if tv_name in ds.variables:
names_to_check = [tv_name] + list(ds[tv_name].dims)
# try searching for 4-D field instead of variable name for a specific level
# (e.g., U instead of U500)
elif len(var.translation.scalar_coords) > 0:
var_basename = ''.join(filter(lambda x: not x.isdigit(), tv_name))
if var_basename in ds.variables:
names_to_check = [var_basename] + list(ds[var_basename].dims)
else:
raise util.MetadataError(f'Did not find variable key {tv_name} or {var_basename}'
f'in queried xarray dataset.')

for v_name in names_to_check:
try:
self.check_metadata(ds[v_name], 'standard_name', 'units')
Expand Down

0 comments on commit f34b6f4

Please sign in to comment.