Skip to content

Commit

Permalink
wrf additions
Browse files Browse the repository at this point in the history
  • Loading branch information
allibco committed Sep 16, 2024
1 parent ea6b748 commit 96a06a4
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 41 deletions.
5 changes: 3 additions & 2 deletions ldcpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from pkg_resources import DistributionNotFound, get_distribution

from .calcs import Datasetcalcs, Diffcalcs
from .collect_datasets import collect_datasets

# from .collect_datasets import collect_datasets
from .comp_checker import CompChecker
from .derived_vars import cam_budgets
from .plot import plot
from .util import (
check_metrics,
combine_datasets,
collect_datasets,
compare_stats,
open_datasets,
save_metrics,
Expand Down
48 changes: 27 additions & 21 deletions ldcpy/calcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from skimage.util import crop
from xrft import dft

from .collect_datasets import collect_datasets
# from .collect_datasets import collect_datasets

xr.set_options(keep_attrs=True)

Expand All @@ -43,7 +43,7 @@ def __init__(
ds: xr.DataArray,
data_type: str,
aggregate_dims: list,
time_dim_name: str = 'time',
time_dim_name: str = None,
lat_dim_name: str = None,
lon_dim_name: str = None,
vert_dim_name: str = None,
Expand Down Expand Up @@ -72,6 +72,7 @@ def __init__(
lat_coord_name = ds.cf.coordinates['latitude'][0]
self._lat_coord_name = lat_coord_name

# WARNING WRF ALSO HAS XLAT_U and XLONG_U, XLAT_v and XLONG_V
dd = ds.cf[ds.cf.coordinates['latitude'][0]].dims

ll = len(dd)
Expand All @@ -85,6 +86,11 @@ def __init__(
lat_dim_name = dd[0]
if lon_dim_name is None:
lon_dim_name = dd[1]
elif data_type == 'wrf':
if lat_dim_name is None:
lat_dim_name = dd[0]
if lon_dim_name is None:
lon_dim_name = dd[1]
else:
print('Warning: unknown data_type: ', data_type)

Expand All @@ -94,12 +100,19 @@ def __init__(

# vertical dimension?
if vert_dim_name is None:
vert = 'vertical' in ds.cf
if vert:
vert_dim_name = ds.cf['vertical'].name
if data_type == 'wrf':
vert = 'z' in ds.dims
if vert:
vert_dim_name = 'z'
else:
vert = 'vertical' in ds.cf
if vert:
vert_dim_name = ds.cf['vertical'].name
self._vert_dim_name = vert_dim_name

# time dimension TO DO: check this (after cf_xarray update)
if time_dim_name is None:
time_dim_name = ds.cf.coordinates['time']
self._time_dim_name = time_dim_name

self._quantile = q
Expand Down Expand Up @@ -1493,9 +1506,6 @@ def get_calc_ds(self, calc_name: str, var_name: str) -> xr.Dataset:
da = self.get_calc(calc_name)
ds = da.squeeze().to_dataset(name=var_name, promote_attrs=True)
ds.attrs['data_type'] = da.data_type
# new_ds = collect_datasets(self._ds.data_type, [var_name], [ds],
# [self._ds.set_name])
# new_ds = new_ds.astype(self.dtype)
return ds

def get_calc(self, name: str, q: Optional[int] = 0.5, grouping: Optional[str] = None, ddof=1):
Expand Down Expand Up @@ -2074,16 +2084,9 @@ def ssim_value(self):
This creates two plots and uses the standard SSIM.
"""

# import tempfile

# import skimage.io
# import skimage.metrics
# from skimage.metrics import structural_similarity as ssim

k1 = self._k1
k2 = self._k2

# if not self._is_memoized('_ssim_value'):
if True:
# Prevent showing stuff
backend_ = mpl.get_backend()
Expand All @@ -2101,25 +2104,28 @@ def ssim_value(self):
central = 0.0 # might make this a parameter later
if self._data_type == 'pop':
central = 300.0
# make periodic
# make periodic for pop or cam-fv
if self._data_type == 'pop':
cy_lon1 = np.hstack((lon1, lon1[:, 0:1]))
cy_lon2 = np.hstack((lon2, lon2[:, 0:1]))

cy_lat1 = np.hstack((lat1, lat1[:, 0:1]))
cy_lat2 = np.hstack((lat2, lat2[:, 0:1]))

cy1 = add_cyclic_point(d1)
cy2 = add_cyclic_point(d2)
no_inf_d1 = np.nan_to_num(cy1, nan=np.nan)
no_inf_d2 = np.nan_to_num(cy2, nan=np.nan)

else: # cam-fv
elif self._data_type == 'cam-fv': # cam-fv
cy1, cy_lon1 = add_cyclic_point(d1, coord=lon1)
cy2, cy_lon2 = add_cyclic_point(d2, coord=lon2)
cy_lat1 = lat1
cy_lat2 = lat2
no_inf_d1 = np.nan_to_num(cy1, nan=np.nan)
no_inf_d2 = np.nan_to_num(cy2, nan=np.nan)

no_inf_d1 = np.nan_to_num(cy1, nan=np.nan)
no_inf_d2 = np.nan_to_num(cy2, nan=np.nan)
elif self._data_type == 'wrf':
no_inf_d1 = np.nan_to_num(d1, nan=np.nan)
no_inf_d2 = np.nan_to_num(d2, nan=np.nan)

# is it 3D? must do each level
if self._calcs1._vert_dim_name is not None:
Expand Down
2 changes: 1 addition & 1 deletion ldcpy/collect_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def preprocess(ds, varnames):
return ds[varnames]


def collect_datasets(data_type, varnames, list_of_ds, labels, **kwargs):
def OLD_collect_datasets(data_type, varnames, list_of_ds, labels, **kwargs):
"""
Concatenate several different xarray datasets across a new
"collection" dimension, which can be accessed with the specified
Expand Down
31 changes: 19 additions & 12 deletions ldcpy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def get_calcs(self, da, data_type):
if data_type == 'cam-fv': # 1d
lat_dim = dd[0]
lon_dim = da_data.cf['longitude'].dims[0]
elif data_type == 'pop': # 2d
elif data_type == 'pop' or data_type == 'wrf': # 2d
lat_dim = dd[0]
lon_dim = dd[1]

Expand Down Expand Up @@ -318,6 +318,7 @@ def spatial_plot(self, da_sets, titles, data_type):
cmin = []

# lat/lon could be 1 or 2d and have different names
# TO - will need to adjust for WRF for U and V?
lon_coord_name = da_sets[0].cf.coordinates['longitude'][0]
lat_coord_name = da_sets[0].cf.coordinates['latitude'][0]

Expand All @@ -327,16 +328,18 @@ def spatial_plot(self, da_sets, titles, data_type):
if data_type == 'pop':
central = 300.0

# projection:
if data_type == 'wrf':
myproj = ccrs.PlateCarree()
else:
myproj = ccrs.Robinson(central_longitude=central)

for i in range(da_sets.sets.size):

if self.vert_plot:
axs[i] = plt.subplot(
nrows, 1, i + 1, projection=ccrs.Robinson(central_longitude=central)
)
axs[i] = plt.subplot(nrows, 1, i + 1, projection=myproj)
else:
axs[i] = plt.subplot(
nrows, ncols, i + 1, projection=ccrs.Robinson(central_longitude=central)
)
axs[i] = plt.subplot(nrows, ncols, i + 1, projection=myproj)

axs[i].set_facecolor('#39ff14')

Expand All @@ -355,6 +358,10 @@ def spatial_plot(self, da_sets, titles, data_type):
lat_sets = da_sets[i][lat_coord_name]

cy_datas = add_cyclic_point(da_sets[i])
elif data_type == 'wrf':
lat_sets = da_sets[i][lat_coord_name]
lon_sets = da_sets[i][lon_coord_name]
cy_datas = da_sets[i]

if np.isnan(cy_datas).any() or np.isinf(cy_datas).any():
nan_inf_flag = 1
Expand Down Expand Up @@ -388,7 +395,7 @@ def spatial_plot(self, da_sets, titles, data_type):

# casting to float32 from float64 using imshow prevents lots of tiny black dots from showing up in some plots with lots of
# zeroes. See plot of probability of negative PRECT to see this in action.
if data_type == 'pop':
if data_type == 'pop' or data_type == 'wrf':
psets[i] = psets[i] = axs[i].pcolormesh(
lon_sets,
lat_sets,
Expand All @@ -406,7 +413,7 @@ def spatial_plot(self, da_sets, titles, data_type):

axs[i].set_global()

if data_type == 'cam-fv':
if data_type == 'cam-fv' or data_type == 'wrf':
axs[i].coastlines()
elif data_type == 'pop':
axs[i].add_feature(
Expand Down Expand Up @@ -701,7 +708,7 @@ def get_calc_label(self, calc, data, data_type):
if data_type == 'cam-fv': # 1D
lat_dim = dd[0]
lon_dim = data.cf['longitude'].dims[0]
elif data_type == 'pop': # 2D
elif data_type == 'pop' or data_type == 'wrf': # 2D
lat_dim = dd[0]
lon_dim = dd[1]

Expand Down Expand Up @@ -1120,7 +1127,7 @@ class in ldcpy.plot, for more information about the available calcs see ldcpy.Da
raw_calcs.append(mp.get_calcs(d, ds.data_type))

# get lat/lon coordinate names:
if ds.data_type == 'pop':
if ds.data_type == 'pop' or ds.data_type == 'wrf':
lon_coord_name = datas[0].cf[datas[0].cf.coordinates['longitude'][0]].dims[1]
lat_coord_name = datas[0].cf[datas[0].cf.coordinates['latitude'][0]].dims[0]
elif ds.data_type == 'cam-fv': # cam-fv
Expand All @@ -1143,7 +1150,7 @@ class in ldcpy.plot, for more information about the available calcs see ldcpy.Da
if ds.data_type == 'cam-fv': # 1D
mp.title_lat = subsets[0][lat_coord_name].data[0]
mp.title_lon = subsets[0][lon_coord_name].data[0] - 180
elif ds.data_type == 'pop': # 2D
elif ds.data_type == 'pop' or 'wrf': # 2D
# lon should be 0- 360
mylat = subsets[0][lat_coord_name].data[0]
mylon = subsets[0][lon_coord_name].data[0]
Expand Down
Loading

0 comments on commit 96a06a4

Please sign in to comment.