Refactor catalog write #587

Merged
2 changes: 1 addition & 1 deletion mdtf_framework.py
@@ -199,7 +199,7 @@ def backup_config(config):
# rename vars in cat_subset to align with POD convention
cat_subset = data_pp.rename_dataset_vars(cat_subset, cases)
# write the ESM intake catalog for the preprocessed files
data_pp.write_pp_catalog(cat_subset, model_paths, log.log)
data_pp.write_pp_catalog(cases, cat_subset, model_paths, log.log)
# configure the runtime environments and run the POD(s)
if not any(p.failed for p in pods.values()):
log.log.info("### %s: running pods '%s'.", [p for p in pods.keys()])
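The call-site change above reflects the refactored write_pp_catalog signature: the per-case dictionary is now the first argument so the writer can walk each case's varlist (see the src/preprocessor.py changes below). A minimal stand-in sketch of the new calling convention; the stub function, case names, and arguments are hypothetical, not framework code.

import logging

def write_pp_catalog(cases: dict, cat_subset, model_paths, log) -> None:
    # stub mirroring the refactored argument order: cases first, then the
    # queried catalog subset, the path manager, and the logger
    for case_name in cases:
        log.info("building pp catalog entries for case %s", case_name)

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("sketch")
write_pp_catalog({"CASE_A": {}, "CASE_B": {}}, cat_subset=None, model_paths=None, log=log)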
155 changes: 2 additions & 153 deletions src/data_sources.py
@@ -12,64 +12,6 @@
# RegexPattern that matches any string (path) that doesn't end with ".nc".
ignore_non_nc_regex = util.RegexPattern(r".*(?<!\.nc)")

sample_data_regex = util.RegexPattern(
r"""
(?P<sample_dataset>\S+)/ # first directory: model name
(?P<frequency>\w+)/ # subdirectory: data frequency
# file name = model name + variable name + frequency
(?P=sample_dataset)\.(?P<variable>\w+)\.(?P=frequency)\.nc
""",
input_field="remote_path",
match_error_filter=ignore_non_nc_regex
)


@util.regex_dataclass(sample_data_regex)
class SampleDataFile:
"""Dataclass describing catalog entries for sample model data files.
"""
sample_dataset: str = util.MANDATORY
frequency: util.DateFrequency = util.MANDATORY
variable: str = util.MANDATORY
remote_path: str = util.MANDATORY


@util.mdtf_dataclass
class DataSourceAttributesBase:
"""Class defining attributes that any DataSource needs to specify:

- *CASENAME*: User-supplied label to identify output of this run of the
package.
- *FIRSTYR*, *LASTYR*, *date_range*: Analysis period, specified as a closed
interval (i.e. running from 1 Jan of FIRSTYR through 31 Dec of LASTYR).
- *CASE_ROOT_DIR*: Root directory containing input model data. Different
DataSources may interpret this differently.
- *convention*: name of the variable naming convention used by the source of
model data.
"""
CASENAME: str = util.MANDATORY
FIRSTYR: str = util.MANDATORY
LASTYR: str = util.MANDATORY
date_range: util.DateRange = dataclasses.field(init=False)
CASE_ROOT_DIR: str = ""

log: dataclasses.InitVar = _log

def _set_case_root_dir(self, log=_log):
config = {}
if not self.CASE_ROOT_DIR and config.CASE_ROOT_DIR:
log.debug("Using global CASE_ROOT_DIR = '%s'.", config.CASE_ROOT_DIR)
self.CASE_ROOT_DIR = config.CASE_ROOT_DIR
# verify case root dir exists
if not os.path.isdir(self.CASE_ROOT_DIR):
log.critical("Data directory CASE_ROOT_DIR = '%s' not found.",
self.CASE_ROOT_DIR)
util.exit_handler(code=1)

def __post_init__(self, log=_log):
self._set_case_root_dir(log=log)
self.date_range = util.DateRange(self.FIRSTYR, self.LASTYR)


class DataSourceBase(util.MDTFObjectBase, util.CaseLoggerMixin):
"""DataSource for handling POD sample model data for multirun cases stored on a local filesystem.
@@ -85,7 +27,8 @@ class DataSourceBase(util.MDTFObjectBase, util.CaseLoggerMixin):
env_vars: util.WormDict()
query: dict = dict(frequency="",
path="",
standard_name=""
standard_name="",
realm=""
)

def __init__(self, case_name: str,
@@ -152,80 +95,6 @@ class CMIPDataSource(DataSourceBase):
# varlist = diagnostic.varlist
convention: str = "CMIP"

## NOTE: the __post_init__ method below is retained for reference in case
## we need to define all CMIP6 DRS attributes for the catalog query
#def __post_init__(self, log=_log, model=None, experiment=None):
# config = {}
# cv = cmip6.CMIP6_CVs()

# def _init_x_from_y(source, dest):
# if not getattr(self, dest, ""):
# try:
# source_val = getattr(self, source, "")
# if not source_val:
# raise KeyError()
# dest_val = cv.lookup_single(source_val, source, dest)
# log.debug("Set %s='%s' based on %s='%s'.",
# dest, dest_val, source, source_val)
# setattr(self, dest, dest_val)
# except KeyError:
# log.debug("Couldn't set %s from %s='%s'.",
# dest, source, source_val)
# setattr(self, dest, "")

# if not self.CASE_ROOT_DIR and config.CASE_ROOT_DIR:
# log.debug("Using global CASE_ROOT_DIR = '%s'.", config.CASE_ROOT_DIR)
# self.CASE_ROOT_DIR = config.CASE_ROOT_DIR
# verify case root dir exists
# if not os.path.isdir(self.CASE_ROOT_DIR):
# log.critical("Data directory CASE_ROOT_DIR = '%s' not found.",
# self.CASE_ROOT_DIR)
# util.exit_handler(code=1)

# should really fix this at the level of CLI flag synonyms
# if model and not self.source_id:
# self.source_id = model
# if experiment and not self.experiment_id:
# self.experiment_id = experiment

# # validate non-empty field values
# for field in dataclasses.fields(self):
# val = getattr(self, field.name, "")
# if not val:
# continue
# try:
# if not cv.is_in_cv(field.name, val):
# log.error(("Supplied value '%s' for '%s' is not recognized by "
# "the CMIP6 CV. Continuing, but queries will probably fail."),
# val, field.name)
# except KeyError:
# # raised if not a valid CMIP6 CV category
# continue
# # currently no inter-field consistency checks: happens implicitly, since
# # set_experiment will find zero experiments.

# # Attempt to determine first few fields of DRS, to avoid having to crawl
# # entire DRS structure
# _init_x_from_y('experiment_id', 'activity_id')
# _init_x_from_y('source_id', 'institution_id')
# _init_x_from_y('institution_id', 'source_id')
# # TODO: multi-column lookups
# # set CATALOG_DIR to be further down the hierarchy if possible, to
# # avoid having to crawl entire DRS structure; CASE_ROOT_DIR remains the
# # root of the DRS hierarchy
# new_root = self.CASE_ROOT_DIR
# for drs_attr in ("activity_id", "institution_id", "source_id", "experiment_id"):
# drs_val = getattr(self, drs_attr, "")
# if not drs_val:
# break
# new_root = os.path.join(new_root, drs_val)
# if not os.path.isdir(new_root):
# log.error("Data directory '%s' not found; starting crawl at '%s'.",
# new_root, self.CASE_ROOT_DIR)
# self.CATALOG_DIR = self.CASE_ROOT_DIR
# else:
# self.CATALOG_DIR = new_root


@data_source.maker
class CESMDataSource(DataSourceBase):
@@ -247,23 +116,3 @@ class GFDLDataSource(DataSourceBase):
# col_spec = sampleLocalFileDataSource_col_spec
# varlist = diagnostic.varlist
convention: str = "GFDL"



dummy_regex = util.RegexPattern(
r"""(?P<dummy_group>.*) # match everything; RegexPattern needs >= 1 named groups
""",
input_field="remote_path",
match_error_filter=ignore_non_nc_regex
)


@util.regex_dataclass(dummy_regex)
class GlobbedDataFile:
"""Applies a trivial regex to the paths returned by the glob."""
dummy_group: str = util.MANDATORY
remote_path: str = util.MANDATORY
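
The one functional change that survives the deletions in src/data_sources.py is the extra realm key in the default query dict that every DataSourceBase subclass (CMIP, CESM, GFDL) inherits. A hedged sketch of how a query dict like this is applied to an ESM intake catalog; the catalog path and field values below are hypothetical, while intake.open_esm_datastore and search() are standard intake-esm calls.

# Hedged sketch, not framework code: apply a DataSourceBase-style query dict
# to an ESM intake catalog. Requires the intake-esm plugin to be installed.
import intake

query = dict(
    frequency="day",                      # populated per-variable by query_catalog
    path=["*my_case*"],                   # path pattern built from the case name
    standard_name="air_temperature",
    realm="atmos*",                       # new key added in this PR
)

cat = intake.open_esm_datastore("/path/to/data_catalog.json")  # hypothetical path
subset = cat.search(**query)              # search matches on the catalog's columns
print(subset.df.head())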




73 changes: 40 additions & 33 deletions src/preprocessor.py
@@ -875,34 +875,35 @@ def query_catalog(self,
# path_regex = '*' + case_name + '*'
freq = case_d.varlist.T.frequency.format()

for v in case_d.varlist.iter_vars():
realm_regex = v.realm + '*'
date_range = v.translation.T.range
for var in case_d.varlist.iter_vars():
realm_regex = var.realm + '*'
date_range = var.translation.T.range
# define initial query dictionary with variable settings requirements that do not change if
# the variable is translated
# TODO: add method to convert freq from DateFrequency object to string
case_d.query['frequency'] = freq
case_d.query['path'] = [path_regex]
case_d.query['variable_id'] = v.translation.name
# search translation for further query requirements
for q in case_d.query:
if hasattr(v.translation, q):
case_d.query.update({q: getattr(v.translation, q)})
case_d.query['variable_id'] = var.translation.name
case_d.query['realm'] = realm_regex
case_d.query['standard_name'] = var.translation.standard_name

# change realm key name if necessary
if cat.df.get('modeling_realm', None) is not None:
case_d.query['modeling_realm'] = case_d.query.pop('realm')

# search catalog for convention specific query object
cat_subset = cat.search(**case_d.query)
if cat_subset.df.empty:
# check whether there is an alternate variable to substitute
if any(v.alternates):
if any(var.alternates):
try_new_query = True
for a in v.alternates:
for a in var.alternates:
case_d.query.update({'variable_id': a.name})
if any(v.translation.scalar_coords):
if any(var.translation.scalar_coords):
found_z_entry = False
# check for vertical coordinate to determine if level extraction is needed
for c in a.scalar_coords:
if c.axis == 'Z':
v.translation.requires_level_extraction = True
var.translation.requires_level_extraction = True
found_z_entry = True
break
else:
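
A detail worth noting in the rewritten query construction above: some catalogs label the realm column modeling_realm rather than realm, so the query key is renamed before cat.search(**case_d.query) runs. A small self-contained illustration of that guard; the DataFrame columns and query values are invented.

# Stand-alone illustration of the realm / modeling_realm key swap.
import pandas as pd

query = {"frequency": "day", "variable_id": "tas",
         "standard_name": "air_temperature", "path": ["*case*"], "realm": "atmos*"}

# hypothetical catalog whose realm column is named 'modeling_realm'
catalog_df = pd.DataFrame(columns=["variable_id", "modeling_realm", "path"])

# DataFrame.get returns None when the column is absent, so this only fires
# for catalogs that actually carry a 'modeling_realm' column
if catalog_df.get("modeling_realm", None) is not None:
    query["modeling_realm"] = query.pop("realm")

print(sorted(query))  # 'realm' has been replaced by 'modeling_realm'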
@@ -1257,6 +1258,7 @@ def process(self,
return cat_subset

def write_pp_catalog(self,
cases: dict,
input_catalog_ds: xr.Dataset,
config: util.PodPathManager,
log: logging.log):
@@ -1267,26 +1269,32 @@
pp_cat_assets = util.define_pp_catalog_assets(config, cat_file_name)
file_list = util.get_file_list(config.OUTPUT_DIR)
# fill in catalog information from pp file name
entries = [e.cat_entry for e in list(map(util.catalog.ppParser['ppParser' + 'GFDL'], file_list))]
# append columns defined in assets
columns = [att['column_name'] for att in pp_cat_assets['attributes']]
for col in columns:
for e in entries:
if col not in e.keys():
e[col] = ""
# copy information from input catalog to pp catalog entries
global_attrs = ['convention', 'realm']
for e in entries:
ds_match = input_catalog_ds[e['dataset_name']]
for att in global_attrs:
e[att] = ds_match.attrs.get(att, '')
ds_var = ds_match.data_vars.get(e['variable_id'])
for key, val in ds_var.attrs.items():
if key in columns:
e[key] = val

# create a Pandas dataframe from the catalog entries
cat_df = pd.DataFrame(entries)
cat_entries = []
# each key is a case
for case_name, case_dict in cases.items():
ds_match = input_catalog_ds[case_name]
for var in case_dict.varlist.iter_vars():
ds_var = ds_match.data_vars.get(var.translation.name, None)
if ds_var is None:
log.error(f'No var {var.translation.name}')
d = dict.fromkeys(columns, "")
for key, val in ds_match.attrs.items():
if 'intake_esm_attrs' in key:
for c in columns:
if key.split('intake_esm_attrs:')[1] == c:
d[c] = val
if var.translation.convention == 'no_translation':
d.update({'convention': var.convention})
else:
d.update({'convention': var.translation.convention})
d.update({'path': var.dest_path})
cat_entries.append(d)

# create a Pandas dataframe from the catalog entries

cat_df = pd.DataFrame(cat_entries)
cat_df.head()
# validate the catalog
try:
Expand All @@ -1298,7 +1306,7 @@ def write_pp_catalog(self,
)
)
except Exception as exc:
log.error(f'Unable to validate esm intake catalog for pp data: {exc}')
log.error(f'Error validating ESM intake catalog for pp data: {exc}')
try:
log.debug(f'Writing pp data catalog {cat_file_name} csv and json files to {config.OUTPUT_DIR}')
validated_cat.serialize(cat_file_name,
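
The rewritten entry-building loop relies on the intake_esm_attrs:<column> attributes that intake-esm records on each dataset it opens, copying any that match a pp-catalog column and then filling convention and path from the variable itself. A hedged, self-contained illustration of that key-splitting pattern; the column list, attribute values, and destination path are invented.

# Sketch of the attribute-harvesting pattern used in write_pp_catalog;
# the xarray Dataset and its attrs are fabricated for illustration.
import xarray as xr

columns = ["activity_id", "variable_id", "frequency", "convention", "path"]
ds = xr.Dataset(attrs={
    "intake_esm_attrs:activity_id": "CMIP",
    "intake_esm_attrs:variable_id": "tas",
    "intake_esm_attrs:frequency": "day",
})

entry = dict.fromkeys(columns, "")
for key, val in ds.attrs.items():
    if "intake_esm_attrs" in key:
        col = key.split("intake_esm_attrs:")[1]
        if col in columns:
            entry[col] = val

entry["convention"] = "CMIP"                      # from var.(translation.)convention
entry["path"] = "/pp/output/CASE_A.tas.day.nc"    # hypothetical var.dest_path
print(entry)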
@@ -1334,7 +1342,6 @@ def process(self, case_list: dict,
cat_subset = self.query_catalog(case_list, config.DATA_CATALOG)
for case_name, case_xr_dataset in cat_subset.items():
for v in case_list[case_name].varlist.iter_vars():
self.edit_request(v, convention=cat_subset[case_name].convention)
cat_subset[case_name] = self.parse_ds(v, case_xr_dataset)

return cat_subset