From f34b6f4363394afcd0d3faf014c170be3dd58ea1 Mon Sep 17 00:00:00 2001 From: Jess <20195932+wrongkindofdoctor@users.noreply.github.com> Date: Tue, 9 Jul 2024 15:03:04 -0400 Subject: [PATCH] Add alternate support (#604) * add function new_dict_wo_key to util/basic.py * add boolean attribute is_alternate to VarlistEntry class remove iter_shallow_alternates function from varlist_util.py refactor alternates attribute so that it is a list of VarlistEntry classes instead of a nested list add call to new_dict_wo_key to remove alternate VarlistEntry keys from Varlist dict after they have been appended as attributes to the variable VarlistEntry that they can be substituted for * refine alternate check in preprocessor query_catalog to use translation info attached to alternates if avaiable need to fix issues with errors thrown by metadata checks for cases where 4d field is present, but doesn't match query for specific level * add varlist_util to commit * add logic to sub long_name for standard_name if necessary to xr_parser check_ds_attrs add check to sub 4D field name for variable level name in ds search to check_ds_attrs --- src/preprocessor.py | 11 +++++-- src/util/__init__.py | 2 +- src/util/basic.py | 7 +++++ src/varlist_util.py | 69 ++++++++++++++++++++++++++++++-------------- src/xr_parser.py | 18 +++++++++++- 5 files changed, 82 insertions(+), 25 deletions(-) diff --git a/src/preprocessor.py b/src/preprocessor.py index 2bbd7721d..cbd45a154 100644 --- a/src/preprocessor.py +++ b/src/preprocessor.py @@ -557,6 +557,7 @@ def edit_request(self, v: varlist_util.VarlistEntry, **kwargs): # add original 4D var defined in new_tv as an alternate TranslatedVarlistEntry # to query if no entries on specified levels are found in the data catalog + v.alternates.append(new_tv) return v @@ -922,7 +923,13 @@ def query_catalog(self, if any(var.alternates): try_new_query = True for a in var.alternates: - case_d.query.update({'variable_id': a.name}) + if hasattr(a, 'translation'): + if a.translation is not None: + case_d.query.update({'variable_id': a.translation.name}) + case_d.query.update({'standard_name': a.translation.standard_name}) + else: + case_d.query.update({'variable_id': a.name}) + case_d.query.update({'standard_name': a.standard_name}) if any(var.translation.scalar_coords): found_z_entry = False # check for vertical coordinate to determine if level extraction is needed @@ -1275,7 +1282,7 @@ def process(self, cat_subset = self.query_catalog(case_list, config.DATA_CATALOG) for case_name, case_xr_dataset in cat_subset.items(): for v in case_list[case_name].varlist.iter_vars(): - tv_name = v.translation.name #abbreviate + tv_name = v.translation.name var_xr_dataset = self.parse_ds(v, case_xr_dataset) cat_subset[case_name]['time'] = var_xr_dataset['time'] cat_subset[case_name].update({tv_name: var_xr_dataset[tv_name]}) diff --git a/src/util/__init__.py b/src/util/__init__.py index c1a34a164..ce78b0760 100644 --- a/src/util/__init__.py +++ b/src/util/__init__.py @@ -6,7 +6,7 @@ ConsistentDict, WormDefaultDict, NameSpace, MDTFEnum, sentinel_object_factory, MDTF_ID, deactivate, ObjectStatus, is_iterable, to_iter, from_iter, remove_prefix, RegexDict, - remove_suffix, filter_kwargs, splice_into_list + remove_suffix, filter_kwargs, splice_into_list, new_dict_wo_key ) from .logs import ( diff --git a/src/util/basic.py b/src/util/basic.py index a79b3b615..f3c7c0416 100644 --- a/src/util/basic.py +++ b/src/util/basic.py @@ -3,6 +3,7 @@ import abc import collections import collections.abc +import copy import enum import itertools import string @@ -709,3 +710,9 @@ def get_all_matching_values(self, queries: list): return (match for query in queries for match in self.get_matching_value(query)) +# source: https://stackoverflow.com/questions/5844672/delete-an-element-from-a-dictionary +def new_dict_wo_key(dictionary: dict, key: str) -> dict: + """Returns a **shallow** copy of the input dictionary without a key.""" + _dict = copy.copy(dictionary) + _dict.pop(key, None) + return _dict diff --git a/src/varlist_util.py b/src/varlist_util.py index a999b2f47..ab64fc5d9 100644 --- a/src/varlist_util.py +++ b/src/varlist_util.py @@ -188,6 +188,7 @@ class VarlistEntry(VarlistEntryBase, util.MDTFObjectBase, data_model.DMVariable, status: util.ObjectStatus = dc.field(default=util.ObjectStatus.NOTSET, compare=False) name: str = util.MANDATORY _parent: typing.Any = dc.field(default=util.MANDATORY, compare=False) + is_alternate: bool = False def __post_init__(self, coords=None): # set up log (VarlistEntryLoggerMixin) @@ -215,11 +216,12 @@ def __post_init__(self, coords=None): self.path_variable = self.name.upper() + _file_env_var_suffix # self.alternates is either [] or a list of nonempty lists of VEs if hasattr(self, 'alternates'): - if isinstance(self.alternates, list): - if any(self.alternates): - if not isinstance(self.alternates[0], list): - self.alternates = [self.alternates] - self.alternates = [vs for vs in self.alternates if vs] + if not isinstance(self.alternates, list): + self.alternates = [self.alternates] + else: + self.alternates = [] + if self.requirement == VarlistEntryRequirement.ALTERNATE: + self.is_alternate = True if hasattr(self, 'scalar_coords'): self.scalar_coords = self.scalar_coords @@ -332,10 +334,15 @@ def set_env_vars(self): }) for ax, dim in self.dim_axes.items(): - trans_dim = self.translation.axes[ax] - self.env_vars[dim.name + _coord_env_var_suffix] = trans_dim.name - if trans_dim.has_bounds: - self.env_vars[dim.name + _coord_bounds_env_var_suffix] = trans_dim.bounds + if self.translation is not None: + trans_dim = self.translation.axes[ax] + self.env_vars[dim.name + _coord_env_var_suffix] = trans_dim.name + if trans_dim.has_bounds: + self.env_vars[dim.name + _coord_bounds_env_var_suffix] = trans_dim.bounds + else: + self.env_vars[dim.name + _coord_env_var_suffix] = dim.name + if dim.has_bounds: + self.env_vars[dim.name + _coord_bounds_env_var_suffix] = dim.bounds def iter_alternates(self): """Breadth-first traversal of "sets" of alternate VarlistEntries, @@ -567,8 +574,30 @@ def setup_var(self, # store but don't deactivate, because preprocessor.edit_request() # may supply alternate variables v.log.store_exception(chained_exc) - # set the VarlistEntry env_vars (required for backwards compatibility with first-gen PODs) + + # set the VarlistEntry env_vars (required for backwards compatibility with first-gen PODs v.set_env_vars() + # Translate alternate vars if necessary + for alt_v in v.alternates: + if alt_v.T is not None: + alt_v.change_coord( + 'T', + new_class={ + 'self': VarlistTimeCoordinate, + 'range': util.DateRange, + 'frequency': util.DateFrequency + }, + range=date_range, + calendar=util.NOTSET, + units=util.NOTSET + ) + alt_v.dest_path = self.variable_dest_path(model_paths, case_name, alt_v) + trans_alt = translate.translate(alt_v, from_convention) + alt_v.translation = trans_alt + if trans_alt is None: + alt_v.info(f'Note: alternate variable {alt_v.full_name} not translated from {from_convention} to' + f' {to_convention}') + alt_v.set_env_vars() def variable_dest_path(self, model_paths: util.ModelDataPathManager, @@ -615,13 +644,6 @@ def _pod_dimension_from_struct(name, dd, v_settings): except Exception: raise ValueError(f"Couldn't parse dimension entry for {name}: {dd}") - def _iter_shallow_alternates(var): - """Iterator over all VarlistEntries referenced as alternates. Doesn't - traverse alternates of alternates, etc. - """ - for alt_vs in var.alternates: - yield from alt_vs - vlist_settings = util.coerce_to_dataclass( parent.pod_data, VarlistSettings) globals_d = vlist_settings.global_settings @@ -634,12 +656,17 @@ def _iter_shallow_alternates(var): } for v in vlist_vars.values(): # validate & replace names of alt vars with references to VE objects - for altv_name in _iter_shallow_alternates(v): + for altv_name in v.alternates: if altv_name not in vlist_vars: raise ValueError((f"Unknown variable name {altv_name} listed " f"in alternates for varlist entry {v.name}.")) - linked_alts = [] - for alts in v.alternates: - linked_alts.append([vlist_vars[v_name] for v_name in alts]) + linked_alts = [vlist_vars[v_name] for v_name in v.alternates] v.alternates = linked_alts + alt_vars = [k for k, v in vlist_vars.items() if v.is_alternate] + for a in alt_vars: + vlist_vars = util.new_dict_wo_key(vlist_vars, a) + + # remove alternates from VarlistEntries since they are now attributes of variable + # VarlistEntry objects that they can be substituted for + return cls(contents=list(vlist_vars.values())) diff --git a/src/xr_parser.py b/src/xr_parser.py index 3781141b9..0253285fa 100644 --- a/src/xr_parser.py +++ b/src/xr_parser.py @@ -1218,6 +1218,11 @@ def check_metadata(self, ds_var, *attr_names): if attr not in ds_var.attrs: if attr in ds_var.encoding: ds_var.attrs[attr] = ds_var.encoding[attr] + # TODO: maybe move the following block to reconcile_attrs and refactor + elif attr == 'standard_name' and 'long_name' in ds_var.attrs: + ds_var.attrs[attr] = ds_var.attrs['long_name'] + elif attr == 'standard_name' and 'long_name' in ds_var.encoding: + ds_var.attrs[attr] = ds_var.encoding['long_name'] else: ds_var.attrs[attr] = ATTR_NOT_FOUND if ds_var.attrs[attr] is ATTR_NOT_FOUND: @@ -1240,7 +1245,18 @@ def check_ds_attrs(self, var, ds): # Only check attributes on the dependent variable var_name and its # coordinates. tv_name = var.translation.name - names_to_check = [tv_name] + list(ds[tv_name].dims) + if tv_name in ds.variables: + names_to_check = [tv_name] + list(ds[tv_name].dims) + # try searching for 4-D field instead of variable name for a specific level + # (e.g., U instead of U500) + elif len(var.translation.scalar_coords) > 0: + var_basename = ''.join(filter(lambda x: not x.isdigit(), tv_name)) + if var_basename in ds.variables: + names_to_check = [var_basename] + list(ds[var_basename].dims) + else: + raise util.MetadataError(f'Did not find variable key {tv_name} or {var_basename}' + f'in queried xarray dataset.') + for v_name in names_to_check: try: self.check_metadata(ds[v_name], 'standard_name', 'units')