diff --git a/src/preprocessor.py b/src/preprocessor.py index 2bbd7721d..cbd45a154 100644 --- a/src/preprocessor.py +++ b/src/preprocessor.py @@ -557,6 +557,7 @@ def edit_request(self, v: varlist_util.VarlistEntry, **kwargs): # add original 4D var defined in new_tv as an alternate TranslatedVarlistEntry # to query if no entries on specified levels are found in the data catalog + v.alternates.append(new_tv) return v @@ -922,7 +923,13 @@ def query_catalog(self, if any(var.alternates): try_new_query = True for a in var.alternates: - case_d.query.update({'variable_id': a.name}) + if hasattr(a, 'translation'): + if a.translation is not None: + case_d.query.update({'variable_id': a.translation.name}) + case_d.query.update({'standard_name': a.translation.standard_name}) + else: + case_d.query.update({'variable_id': a.name}) + case_d.query.update({'standard_name': a.standard_name}) if any(var.translation.scalar_coords): found_z_entry = False # check for vertical coordinate to determine if level extraction is needed @@ -1275,7 +1282,7 @@ def process(self, cat_subset = self.query_catalog(case_list, config.DATA_CATALOG) for case_name, case_xr_dataset in cat_subset.items(): for v in case_list[case_name].varlist.iter_vars(): - tv_name = v.translation.name #abbreviate + tv_name = v.translation.name var_xr_dataset = self.parse_ds(v, case_xr_dataset) cat_subset[case_name]['time'] = var_xr_dataset['time'] cat_subset[case_name].update({tv_name: var_xr_dataset[tv_name]}) diff --git a/src/util/__init__.py b/src/util/__init__.py index c1a34a164..ce78b0760 100644 --- a/src/util/__init__.py +++ b/src/util/__init__.py @@ -6,7 +6,7 @@ ConsistentDict, WormDefaultDict, NameSpace, MDTFEnum, sentinel_object_factory, MDTF_ID, deactivate, ObjectStatus, is_iterable, to_iter, from_iter, remove_prefix, RegexDict, - remove_suffix, filter_kwargs, splice_into_list + remove_suffix, filter_kwargs, splice_into_list, new_dict_wo_key ) from .logs import ( diff --git a/src/util/basic.py b/src/util/basic.py 
index a79b3b615..f3c7c0416 100644 --- a/src/util/basic.py +++ b/src/util/basic.py @@ -3,6 +3,7 @@ import abc import collections import collections.abc +import copy import enum import itertools import string @@ -709,3 +710,9 @@ def get_all_matching_values(self, queries: list): return (match for query in queries for match in self.get_matching_value(query)) +# source: https://stackoverflow.com/questions/5844672/delete-an-element-from-a-dictionary +def new_dict_wo_key(dictionary: dict, key: str) -> dict: + """Returns a **shallow** copy of the input dictionary without a key.""" + _dict = copy.copy(dictionary) + _dict.pop(key, None) + return _dict diff --git a/src/varlist_util.py b/src/varlist_util.py index a999b2f47..ab64fc5d9 100644 --- a/src/varlist_util.py +++ b/src/varlist_util.py @@ -188,6 +188,7 @@ class VarlistEntry(VarlistEntryBase, util.MDTFObjectBase, data_model.DMVariable, status: util.ObjectStatus = dc.field(default=util.ObjectStatus.NOTSET, compare=False) name: str = util.MANDATORY _parent: typing.Any = dc.field(default=util.MANDATORY, compare=False) + is_alternate: bool = False def __post_init__(self, coords=None): # set up log (VarlistEntryLoggerMixin) @@ -215,11 +216,12 @@ def __post_init__(self, coords=None): self.path_variable = self.name.upper() + _file_env_var_suffix # self.alternates is either [] or a list of nonempty lists of VEs if hasattr(self, 'alternates'): - if isinstance(self.alternates, list): - if any(self.alternates): - if not isinstance(self.alternates[0], list): - self.alternates = [self.alternates] - self.alternates = [vs for vs in self.alternates if vs] + if not isinstance(self.alternates, list): + self.alternates = [self.alternates] + else: + self.alternates = [] + if self.requirement == VarlistEntryRequirement.ALTERNATE: + self.is_alternate = True if hasattr(self, 'scalar_coords'): self.scalar_coords = self.scalar_coords @@ -332,10 +334,15 @@ def set_env_vars(self): }) for ax, dim in self.dim_axes.items(): - trans_dim = 
self.translation.axes[ax] - self.env_vars[dim.name + _coord_env_var_suffix] = trans_dim.name - if trans_dim.has_bounds: - self.env_vars[dim.name + _coord_bounds_env_var_suffix] = trans_dim.bounds + if self.translation is not None: + trans_dim = self.translation.axes[ax] + self.env_vars[dim.name + _coord_env_var_suffix] = trans_dim.name + if trans_dim.has_bounds: + self.env_vars[dim.name + _coord_bounds_env_var_suffix] = trans_dim.bounds + else: + self.env_vars[dim.name + _coord_env_var_suffix] = dim.name + if dim.has_bounds: + self.env_vars[dim.name + _coord_bounds_env_var_suffix] = dim.bounds def iter_alternates(self): """Breadth-first traversal of "sets" of alternate VarlistEntries, @@ -567,8 +574,30 @@ def setup_var(self, # store but don't deactivate, because preprocessor.edit_request() # may supply alternate variables v.log.store_exception(chained_exc) - # set the VarlistEntry env_vars (required for backwards compatibility with first-gen PODs) + + # set the VarlistEntry env_vars (required for backwards compatibility with first-gen PODs) v.set_env_vars() + # Translate alternate vars if necessary + for alt_v in v.alternates: + if alt_v.T is not None: + alt_v.change_coord( + 'T', + new_class={ + 'self': VarlistTimeCoordinate, + 'range': util.DateRange, + 'frequency': util.DateFrequency + }, + range=date_range, + calendar=util.NOTSET, + units=util.NOTSET + ) + alt_v.dest_path = self.variable_dest_path(model_paths, case_name, alt_v) + trans_alt = translate.translate(alt_v, from_convention) + alt_v.translation = trans_alt + if trans_alt is None: + alt_v.info(f'Note: alternate variable {alt_v.full_name} not translated from {from_convention} to' + f' {to_convention}') + alt_v.set_env_vars() def variable_dest_path(self, model_paths: util.ModelDataPathManager, @@ -615,13 +644,6 @@ def _pod_dimension_from_struct(name, dd, v_settings): except Exception: raise ValueError(f"Couldn't parse dimension entry for {name}: {dd}") - def _iter_shallow_alternates(var): """Iterator 
over all VarlistEntries referenced as alternates. Doesn't - traverse alternates of alternates, etc. - """ - for alt_vs in var.alternates: - yield from alt_vs - vlist_settings = util.coerce_to_dataclass( parent.pod_data, VarlistSettings) globals_d = vlist_settings.global_settings @@ -634,12 +656,17 @@ def _iter_shallow_alternates(var): } for v in vlist_vars.values(): # validate & replace names of alt vars with references to VE objects - for altv_name in _iter_shallow_alternates(v): + for altv_name in v.alternates: if altv_name not in vlist_vars: raise ValueError((f"Unknown variable name {altv_name} listed " f"in alternates for varlist entry {v.name}.")) - linked_alts = [] - for alts in v.alternates: - linked_alts.append([vlist_vars[v_name] for v_name in alts]) + linked_alts = [vlist_vars[v_name] for v_name in v.alternates] v.alternates = linked_alts + alt_vars = [k for k, v in vlist_vars.items() if v.is_alternate] + for a in alt_vars: + vlist_vars = util.new_dict_wo_key(vlist_vars, a) + + # remove alternates from VarlistEntries since they are now attributes of variable + # VarlistEntry objects that they can be substituted for + return cls(contents=list(vlist_vars.values())) diff --git a/src/xr_parser.py b/src/xr_parser.py index 3781141b9..0253285fa 100644 --- a/src/xr_parser.py +++ b/src/xr_parser.py @@ -1218,6 +1218,11 @@ def check_metadata(self, ds_var, *attr_names): if attr not in ds_var.attrs: if attr in ds_var.encoding: ds_var.attrs[attr] = ds_var.encoding[attr] + # TODO: maybe move the following block to reconcile_attrs and refactor + elif attr == 'standard_name' and 'long_name' in ds_var.attrs: + ds_var.attrs[attr] = ds_var.attrs['long_name'] + elif attr == 'standard_name' and 'long_name' in ds_var.encoding: + ds_var.attrs[attr] = ds_var.encoding['long_name'] else: ds_var.attrs[attr] = ATTR_NOT_FOUND if ds_var.attrs[attr] is ATTR_NOT_FOUND: @@ -1240,7 +1245,18 @@ def check_ds_attrs(self, var, ds): # Only check attributes on the dependent variable var_name and 
its # coordinates. tv_name = var.translation.name - names_to_check = [tv_name] + list(ds[tv_name].dims) + if tv_name in ds.variables: + names_to_check = [tv_name] + list(ds[tv_name].dims) + # try searching for 4-D field instead of variable name for a specific level + # (e.g., U instead of U500) + elif len(var.translation.scalar_coords) > 0: + var_basename = ''.join(filter(lambda x: not x.isdigit(), tv_name)) + if var_basename in ds.variables: + names_to_check = [var_basename] + list(ds[var_basename].dims) + else: + raise util.MetadataError(f'Did not find variable key {tv_name} or {var_basename}' + f' in queried xarray dataset.') + for v_name in names_to_check: try: self.check_metadata(ds[v_name], 'standard_name', 'units')