Skip to content

Commit

Permalink
Sparse Interpolation QoL
Browse files Browse the repository at this point in the history
  • Loading branch information
tobin-ford committed Sep 15, 2024
1 parent cd52739 commit 1b113c3
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 8 deletions.
51 changes: 43 additions & 8 deletions pvdeg/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,14 +927,17 @@ def elevation_stochastic_downselect(


def interpolate_analysis(
result: xr.Dataset, data_var: str, method="nearest"
result: xr.Dataset, data_var: str, method="nearest", resolution=100j,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Interpolate sparse spatial result data against DataArray coordinates.
Takes DataArray instead of Dataset, index one variable of a dataset to get a dataarray.
Parameters:
-----------
resolution: complex
Change the amount the input is interpolated.
For more interpolation set higher (200j is more than 100j)
Result:
-------
Expand All @@ -951,21 +954,51 @@ def interpolate_analysis(
) # probably a nicer way to do this

grid_lat, grid_lon = np.mgrid[
df["latitude"].min() : df["latitude"].max() : 100j,
df["longitude"].min() : df["longitude"].max() : 100j,
df["latitude"].min() : df["latitude"].max() : resolution,
df["longitude"].min() : df["longitude"].max() : resolution,
]

grid_z = griddata(data[:, 0:2], data[:, 2], xi=(grid_lat, grid_lon), method=method)

return grid_z, grid_lat, grid_lon


def plot_sparse_analysis(result: xr.Dataset, data_var: str, method="nearest") -> None:
# api could be updated to match that of plot_USA
def plot_sparse_analysis(
result: xr.Dataset,
data_var: str,
method="nearest",
resolution:complex=100j,
figsize:tuple=(10,8),
) -> None:
"""
Plot the output of a sparse geospatial analysis using interpolation.
Parameters
-----------
result: xr.Dataset
xarray dataset in memory containing coordinates['longitude', 'latitude'] and at least one datavariable.
data_var: str
name of datavariable to plot from result
method: str
interpolation method.
Options: `'nearest', 'linear', 'cubic'`
See [`scipy.interpolate.griddata`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.griddata.html)
resolution: complex
Change the amount the input is interpolated.
For more interpolation set higher (200j is more than 100j)
Returns
-------
fig, ax: tuple
matplotlib figure and axes of plot
"""

grid_values, lat, lon = interpolate_analysis(
result=result, data_var=data_var, method=method
result=result, data_var=data_var, method=method, resolution=resolution
)

fig = plt.figure()
fig = plt.figure(figsize=figsize)
ax = fig.add_axes([0, 0, 1, 1], projection=ccrs.LambertConformal(), frameon=False)
ax.patch.set_visible(False)

Expand All @@ -977,7 +1010,7 @@ def plot_sparse_analysis(result: xr.Dataset, data_var: str, method="nearest") ->
origin="lower",
cmap="viridis",
transform=ccrs.PlateCarree(),
) # should this be trnsposed
)

shapename = "admin_1_states_provinces_lakes"
states_shp = shpreader.natural_earth(
Expand All @@ -994,7 +1027,9 @@ def plot_sparse_analysis(result: xr.Dataset, data_var: str, method="nearest") ->
cbar = plt.colorbar(img, ax=ax, orientation="vertical", fraction=0.02, pad=0.04)
cbar.set_label("Value")

plt.title("Interpolated Heatmap")
plt.title(f"Interpolated Sparse Analysis, {data_var}")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.show()

return fig, ax
46 changes: 46 additions & 0 deletions pvdeg/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,3 +1262,49 @@ def compare_templates(
return False

return True

def merge_sparse(files: list[str])->xr.Dataset:
"""
Merge an arbitrary number of geospatial analysis results.
Creates monotonically increasing indicies.
Uses `engine='h5netcdf'` for reliability, use h5netcdf to save your results to netcdf files.
Parameters
-----------
files: list[str]
A list of strings representing filepaths to netcdf (.nc) files.
Each file must have the same coordinates, `['latitude','longitude']` and identical datavariables.
Returns
-------
merged_ds: xr.Dataset
Dataset (in memory) with `coordinates = ['latitude','longitude']` and datavariables matching files in
filepaths list.
"""

datasets = [xr.open_dataset(fp, engine='h5netcdf').compute() for fp in files]

latitudes = np.concatenate([ds.latitude.values for ds in datasets])
longitudes = np.concatenate([ds.longitude.values for ds in datasets])
unique_latitudes = np.sort(np.unique(latitudes))
unique_longitudes = np.sort(np.unique(longitudes))

data_vars = datasets[0].data_vars

merged_ds = xr.Dataset(
{var: (['latitude', 'longitude'], np.full((len(unique_latitudes), len(unique_longitudes)), np.nan)) for var in data_vars},
coords={
'latitude': unique_latitudes,
'longitude': unique_longitudes
}
)

for ds in datasets:
lat_inds = np.searchsorted(unique_latitudes, ds.latitude.values)
lon_inds = np.searchsorted(unique_longitudes, ds.longitude.values)

for var in ds.data_vars:
merged_ds[var].values[np.ix_(lat_inds, lon_inds)] = ds[var].values

return merged_ds

0 comments on commit 1b113c3

Please sign in to comment.