diff --git a/pygmtsar/pygmtsar/AWS.py b/pygmtsar/pygmtsar/AWS.py index 75ad5284..2a443398 100644 --- a/pygmtsar/pygmtsar/AWS.py +++ b/pygmtsar/pygmtsar/AWS.py @@ -25,6 +25,7 @@ def download_dem(self, geometry, filename=None, n_jobs=-1, product='1s', skip_ex dem.plot.imshow() """ import xarray as xr + import rioxarray as rio import geopandas as gpd import numpy as np from tqdm.auto import tqdm @@ -53,7 +54,8 @@ def job_tile(product, lon, lat): if response.status_code != 200: return None with io.BytesIO(response.content) as f: - tile = xr.open_dataarray(f, engine='rasterio')\ + #tile = xr.open_dataarray(f, engine='rasterio') + tile = rio.open_rasterio(f, chunks=self.chunksize)\ .squeeze(drop=True)\ .rename({'y': 'lat', 'x': 'lon'})\ .drop_vars('spatial_ref')\ @@ -69,7 +71,7 @@ def job_tile(product, lon, lat): tile_xarrays = joblib.Parallel(n_jobs=n_jobs)(joblib.delayed(job_tile)(product, x, y)\ for x in range(left, right + 1) for y in range(lower, upper + 1)) - dem = xr.combine_by_coords([tile for tile in tile_xarrays if tile is not None])['band_data'] + dem = xr.combine_by_coords([tile for tile in tile_xarrays if tile is not None]) bounds = self.get_bounds(geometry) dem = dem.sel(lat=slice(bounds[1], bounds[3]), lon=slice(bounds[0], bounds[2]))