diff --git a/ldcpy/__init__.py b/ldcpy/__init__.py index d2cee81..37a1eac 100644 --- a/ldcpy/__init__.py +++ b/ldcpy/__init__.py @@ -4,13 +4,14 @@ from pkg_resources import DistributionNotFound, get_distribution from .calcs import Datasetcalcs, Diffcalcs -from .collect_datasets import collect_datasets + +# from .collect_datasets import collect_datasets from .comp_checker import CompChecker from .derived_vars import cam_budgets from .plot import plot from .util import ( check_metrics, - combine_datasets, + collect_datasets, compare_stats, open_datasets, save_metrics, diff --git a/ldcpy/calcs.py b/ldcpy/calcs.py index fc4347c..569b77b 100644 --- a/ldcpy/calcs.py +++ b/ldcpy/calcs.py @@ -28,7 +28,7 @@ from skimage.util import crop from xrft import dft -from .collect_datasets import collect_datasets +# from .collect_datasets import collect_datasets xr.set_options(keep_attrs=True) @@ -43,7 +43,7 @@ def __init__( ds: xr.DataArray, data_type: str, aggregate_dims: list, - time_dim_name: str = 'time', + time_dim_name: str = None, lat_dim_name: str = None, lon_dim_name: str = None, vert_dim_name: str = None, @@ -72,6 +72,7 @@ def __init__( lat_coord_name = ds.cf.coordinates['latitude'][0] self._lat_coord_name = lat_coord_name + # WARNING WRF ALSO HAS XLAT_U and XLONG_U, XLAT_v and XLONG_V dd = ds.cf[ds.cf.coordinates['latitude'][0]].dims ll = len(dd) @@ -85,6 +86,11 @@ def __init__( lat_dim_name = dd[0] if lon_dim_name is None: lon_dim_name = dd[1] + elif data_type == 'wrf': + if lat_dim_name is None: + lat_dim_name = dd[0] + if lon_dim_name is None: + lon_dim_name = dd[1] else: print('Warning: unknown data_type: ', data_type) @@ -94,12 +100,19 @@ def __init__( # vertical dimension? 
if vert_dim_name is None: - vert = 'vertical' in ds.cf - if vert: - vert_dim_name = ds.cf['vertical'].name + if data_type == 'wrf': + vert = 'z' in ds.dims + if vert: + vert_dim_name = 'z' + else: + vert = 'vertical' in ds.cf + if vert: + vert_dim_name = ds.cf['vertical'].name self._vert_dim_name = vert_dim_name # time dimension TO DO: check this (after cf_xarray update) + if time_dim_name is None: + time_dim_name = ds.cf.coordinates['time'] self._time_dim_name = time_dim_name self._quantile = q @@ -1493,9 +1506,6 @@ def get_calc_ds(self, calc_name: str, var_name: str) -> xr.Dataset: da = self.get_calc(calc_name) ds = da.squeeze().to_dataset(name=var_name, promote_attrs=True) ds.attrs['data_type'] = da.data_type - # new_ds = collect_datasets(self._ds.data_type, [var_name], [ds], - # [self._ds.set_name]) - # new_ds = new_ds.astype(self.dtype) return ds def get_calc(self, name: str, q: Optional[int] = 0.5, grouping: Optional[str] = None, ddof=1): @@ -2074,16 +2084,9 @@ def ssim_value(self): This creates two plots and uses the standard SSIM. 
""" - # import tempfile - - # import skimage.io - # import skimage.metrics - # from skimage.metrics import structural_similarity as ssim - k1 = self._k1 k2 = self._k2 - # if not self._is_memoized('_ssim_value'): if True: # Prevent showing stuff backend_ = mpl.get_backend() @@ -2101,25 +2104,28 @@ def ssim_value(self): central = 0.0 # might make this a parameter later if self._data_type == 'pop': central = 300.0 - # make periodic + # make periodic for pop or cam-fv if self._data_type == 'pop': cy_lon1 = np.hstack((lon1, lon1[:, 0:1])) cy_lon2 = np.hstack((lon2, lon2[:, 0:1])) - cy_lat1 = np.hstack((lat1, lat1[:, 0:1])) cy_lat2 = np.hstack((lat2, lat2[:, 0:1])) - cy1 = add_cyclic_point(d1) cy2 = add_cyclic_point(d2) + no_inf_d1 = np.nan_to_num(cy1, nan=np.nan) + no_inf_d2 = np.nan_to_num(cy2, nan=np.nan) - else: # cam-fv + elif self._data_type == 'cam-fv': # cam-fv cy1, cy_lon1 = add_cyclic_point(d1, coord=lon1) cy2, cy_lon2 = add_cyclic_point(d2, coord=lon2) cy_lat1 = lat1 cy_lat2 = lat2 + no_inf_d1 = np.nan_to_num(cy1, nan=np.nan) + no_inf_d2 = np.nan_to_num(cy2, nan=np.nan) - no_inf_d1 = np.nan_to_num(cy1, nan=np.nan) - no_inf_d2 = np.nan_to_num(cy2, nan=np.nan) + elif self._data_type == 'wrf': + no_inf_d1 = np.nan_to_num(d1, nan=np.nan) + no_inf_d2 = np.nan_to_num(d2, nan=np.nan) # is it 3D? 
must do each level if self._calcs1._vert_dim_name is not None: diff --git a/ldcpy/collect_datasets.py b/ldcpy/collect_datasets.py index e8a0442..0dd7e57 100644 --- a/ldcpy/collect_datasets.py +++ b/ldcpy/collect_datasets.py @@ -6,7 +6,7 @@ def preprocess(ds, varnames): return ds[varnames] -def collect_datasets(data_type, varnames, list_of_ds, labels, **kwargs): +def OLD_collect_datasets(data_type, varnames, list_of_ds, labels, **kwargs): """ Concatenate several different xarray datasets across a new "collection" dimension, which can be accessed with the specified diff --git a/ldcpy/plot.py b/ldcpy/plot.py index b86fa74..2f60d2f 100644 --- a/ldcpy/plot.py +++ b/ldcpy/plot.py @@ -159,7 +159,7 @@ def get_calcs(self, da, data_type): if data_type == 'cam-fv': # 1d lat_dim = dd[0] lon_dim = da_data.cf['longitude'].dims[0] - elif data_type == 'pop': # 2d + elif data_type == 'pop' or data_type == 'wrf': # 2d lat_dim = dd[0] lon_dim = dd[1] @@ -318,6 +318,7 @@ def spatial_plot(self, da_sets, titles, data_type): cmin = [] # lat/lon could be 1 or 2d and have different names + # TO - will need to adjust for WRF for U and V? 
lon_coord_name = da_sets[0].cf.coordinates['longitude'][0] lat_coord_name = da_sets[0].cf.coordinates['latitude'][0] @@ -327,16 +328,18 @@ def spatial_plot(self, da_sets, titles, data_type): if data_type == 'pop': central = 300.0 + # projection: + if data_type == 'wrf': + myproj = ccrs.PlateCarree() + else: + myproj = ccrs.Robinson(central_longitude=central) + for i in range(da_sets.sets.size): if self.vert_plot: - axs[i] = plt.subplot( - nrows, 1, i + 1, projection=ccrs.Robinson(central_longitude=central) - ) + axs[i] = plt.subplot(nrows, 1, i + 1, projection=myproj) else: - axs[i] = plt.subplot( - nrows, ncols, i + 1, projection=ccrs.Robinson(central_longitude=central) - ) + axs[i] = plt.subplot(nrows, ncols, i + 1, projection=myproj) axs[i].set_facecolor('#39ff14') @@ -355,6 +358,10 @@ def spatial_plot(self, da_sets, titles, data_type): lat_sets = da_sets[i][lat_coord_name] cy_datas = add_cyclic_point(da_sets[i]) + elif data_type == 'wrf': + lat_sets = da_sets[i][lat_coord_name] + lon_sets = da_sets[i][lon_coord_name] + cy_datas = da_sets[i] if np.isnan(cy_datas).any() or np.isinf(cy_datas).any(): nan_inf_flag = 1 @@ -388,7 +395,7 @@ def spatial_plot(self, da_sets, titles, data_type): # casting to float32 from float64 using imshow prevents lots of tiny black dots from showing up in some plots with lots of # zeroes. See plot of probability of negative PRECT to see this in action. 
- if data_type == 'pop': + if data_type == 'pop' or data_type == 'wrf': psets[i] = psets[i] = axs[i].pcolormesh( lon_sets, lat_sets, @@ -406,7 +413,7 @@ def spatial_plot(self, da_sets, titles, data_type): axs[i].set_global() - if data_type == 'cam-fv': + if data_type == 'cam-fv' or data_type == 'wrf': axs[i].coastlines() elif data_type == 'pop': axs[i].add_feature( @@ -701,7 +708,7 @@ def get_calc_label(self, calc, data, data_type): if data_type == 'cam-fv': # 1D lat_dim = dd[0] lon_dim = data.cf['longitude'].dims[0] - elif data_type == 'pop': # 2D + elif data_type == 'pop' or data_type == 'wrf': # 2D lat_dim = dd[0] lon_dim = dd[1] @@ -1120,7 +1127,7 @@ class in ldcpy.plot, for more information about the available calcs see ldcpy.Da raw_calcs.append(mp.get_calcs(d, ds.data_type)) # get lat/lon coordinate names: - if ds.data_type == 'pop': + if ds.data_type == 'pop' or ds.data_type == 'wrf': lon_coord_name = datas[0].cf[datas[0].cf.coordinates['longitude'][0]].dims[1] lat_coord_name = datas[0].cf[datas[0].cf.coordinates['latitude'][0]].dims[0] elif ds.data_type == 'cam-fv': # cam-fv @@ -1143,7 +1150,7 @@ class in ldcpy.plot, for more information about the available calcs see ldcpy.Da if ds.data_type == 'cam-fv': # 1D mp.title_lat = subsets[0][lat_coord_name].data[0] mp.title_lon = subsets[0][lon_coord_name].data[0] - 180 - elif ds.data_type == 'pop': # 2D + elif ds.data_type == 'pop' or 'wrf': # 2D # lon should be 0- 360 mylat = subsets[0][lat_coord_name].data[0] mylon = subsets[0][lon_coord_name].data[0] diff --git a/ldcpy/util.py b/ldcpy/util.py index d3cace1..ba659e0 100644 --- a/ldcpy/util.py +++ b/ldcpy/util.py @@ -22,7 +22,7 @@ def collect_datasets(data_type, varnames, list_of_ds, labels, **kwargs): Parameters ========== data_type: string - Current data types: :cam-fv, pop + Current data types: :cam-fv, pop, wrf varnames : list The variable(s) of interest to combine across input files (usually just one) list_of_datasets : list @@ -39,6 +39,12 @@ def 
collect_datasets(data_type, varnames, list_of_ds, labels, **kwargs): out : xarray.Dataset a collection containing all the data from the list datasets + Notes + ====== + -WRF data must be postprocessed with xWRF before passing to ldcpy + (e.g., ds = xr.open_dataset(wrf_file, engine="netcdf4").xwrf.postprocess()) + -For now lat/lon info must be in the same file! + """ # Error checking: # list_of_files and labels must be same length @@ -49,16 +55,31 @@ def collect_datasets(data_type, varnames, list_of_ds, labels, **kwargs): # the number of timeslices must be the same sz = np.zeros(len(list_of_ds)) for i, myds in enumerate(list_of_ds): - sz[i] = myds.sizes['time'] + time_name = myds.cf.coordinates['time'][0] + sz[i] = myds.sizes[time_name] indx = np.unique(sz) assert indx.size == 1, 'ERROR: all datasets must have the same length time dimension' + # wrf data must contain lat/lon info in same file (for now) + if data_type == 'wrf': + latlon_found = np.zeros(len(list_of_ds)) + for i, myds in enumerate(list_of_ds): + # XLAT,XLONG,XLAT_U,XLONG_U,XLAT_V,XLONG_V + for j in myds.coords.keys(): + if j == 'XLAT' or j == 'XLONG': + latlon_found[i] += 1 + indx = np.where(latlon_found > 1)[0] + assert len(indx) == len(list_of_ds), 'ERROR: WRF datasets must contain XLAT and XLONG' + + # weights? if data_type == 'cam-fv': weights_name = 'gw' varnames.append(weights_name) elif data_type == 'pop': weights_name = 'TAREA' varnames.append(weights_name) + elif data_type == 'wrf': + weights_name = None # preprocess_vars is here for working on jupyter hub... 
def preprocess_vars(ds, varnames): @@ -72,7 +93,7 @@ def preprocess_vars(ds, varnames): if data_type == 'pop': full_ds.coords['cell_area'] = xr.DataArray(full_ds.variables.mapping.get(weights_name))[0] - else: + elif data_type == 'cam-fv': full_ds.coords['cell_area'] = ( xr.DataArray(full_ds.variables.mapping.get(weights_name)) .expand_dims(lon=full_ds.dims['lon']) @@ -81,7 +102,8 @@ def preprocess_vars(ds, varnames): full_ds.attrs['cell_measures'] = 'area: cell_area' - full_ds = full_ds.drop(weights_name) + if weights_name: + full_ds = full_ds.drop(weights_name) full_ds['collection'] = xr.DataArray(labels, dims='collection') @@ -89,6 +111,19 @@ def preprocess_vars(ds, varnames): full_ds.attrs['data_type'] = data_type full_ds.attrs['file_size'] = None + # from other copy of this function + for v in varnames[:-1]: + new_ds = [] + i = 0 + for label in labels: + new_ds.append(full_ds[v].sel(collection=label)) + new_ds[i].attrs['data_type'] = data_type + new_ds[i].attrs['set_name'] = label + + # d = xr.combine_by_coords(new_ds) + d = xr.concat(new_ds, 'collection') + full_ds[v] = d + return full_ds @@ -107,6 +142,7 @@ def open_datasets(data_type, varnames, list_of_files, labels, weights=True, **kw labels. Stores them in an xarray dataset which can be passed to the ldcpy plot functions. + Parameters ========== data_type: string @@ -127,6 +163,12 @@ def open_datasets(data_type, varnames, list_of_files, labels, weights=True, **kw out : xarray.Dataset a collection containing all the data from the list of files + + Notes + ====== + wrf netcdf data must be postprocessed with xwrf, e.g. + ds = xr.open_dataset(wrf_file, engine="netcdf4").xwrf.postprocess() + So use collect_datasets instead.
""" # Error checking: @@ -134,6 +176,10 @@ def open_datasets(data_type, varnames, list_of_files, labels, weights=True, **kw assert len(list_of_files) == len( labels ), 'ERROR: open_dataset file list and labels arguments must be the same length' + # can't use wrf with this function + assert ( + data_type != 'wrf' + ), 'ERROR: WRF files must be postprocessed with xWRF and passed to collect_dataset' # all must have the same time dimension sz = np.zeros(len(list_of_files)) @@ -163,6 +209,9 @@ def preprocess_vars(ds): elif data_type == 'pop' and weights is True: weights_name = 'TAREA' varnames.append(weights_name) + elif data_type == 'wrf': + weights = False + weights_name = None full_ds = xr.open_mfdataset( list_of_files, @@ -176,7 +225,7 @@ def preprocess_vars(ds): if data_type == 'pop' and weights is True: full_ds.coords['cell_area'] = xr.DataArray(full_ds.variables.mapping.get(weights_name))[0] - elif weights is True: + elif data_type == 'cam-fv' and weights is True: full_ds.coords['cell_area'] = ( xr.DataArray(full_ds.variables.mapping.get(weights_name)) .expand_dims(lon=full_ds.dims['lon']) @@ -250,6 +299,10 @@ def compare_stats( da = ds[varname] data_type = ds.attrs['data_type'] + # no weights for wrf + if data_type == 'wrf': + weighted = False + file_size_dict = ds.attrs['file_size'] if file_size_dict is None: include_file_size = False