Skip to content

Commit

Permalink
Python: Add exact_extract support for multivariable rasters with gdal…
Browse files Browse the repository at this point in the history
…, rasterio
  • Loading branch information
dbaston committed Aug 29, 2024
1 parent 399d0cb commit d60c3f2
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 19 deletions.
69 changes: 50 additions & 19 deletions python/src/exactextract/exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,7 @@ def make_raster_names(root: str, nbands: int) -> list:
return [""]


def prep_raster(rast, name_root=None) -> list:
# TODO add some hooks to allow RasterSource implementations
# defined outside this library to handle other input types.
if rast is None:
return [None]

if isinstance(rast, RasterSource):
return [rast]

if type(rast) in (list, tuple):
if all(isinstance(src, (str, os.PathLike)) for src in rast):
sources = [
prep_raster(src, name_root=os.path.splitext(os.path.basename(src))[0])
for src in rast
]
else:
sources = [prep_raster(src) for src in rast]
return list(chain.from_iterable(sources))

def prep_raster_gdal(rast, name_root=None) -> list:
try:
# eagerly import gdal_array to avoid possible ImportError when reading raster data
from osgeo import gdal, gdal_array # noqa: F401
Expand All @@ -65,6 +47,17 @@ def prep_raster(rast, name_root=None) -> list:
rast = gdal.Open(str(rast))

if isinstance(rast, gdal.Dataset):
# Handle inputs such as a netCDF with multiple variables
if rast.RasterCount == 0 and rast.GetSubDatasets():
sources = []
for subds, _ in rast.GetSubDatasets():
try:
varname = gdal.GetSubdatasetInfo(subds).GetSubdatasetComponent()
except AttributeError:
varname = subds.split(":")[-1]
sources.append(prep_raster_gdal(subds, name_root=varname))
return list(chain.from_iterable(sources))

names = make_raster_names(name_root, rast.RasterCount)
return [
GDALRasterSource(rast, i + 1, name=names[i])
Expand All @@ -73,13 +66,23 @@ def prep_raster(rast, name_root=None) -> list:
except ImportError:
pass


def prep_raster_rasterio(rast, name_root=None) -> list:
try:
import rasterio

if isinstance(rast, (str, os.PathLike)):
rast = rasterio.open(rast)

if isinstance(rast, rasterio.DatasetReader):
# Handle inputs such as a netCDF with multiple variables
if rast.count == 0 and rast.subdatasets:
sources = []
for subds in rast.subdatasets:
varname = subds.split(":")[-1]
sources.append(prep_raster_rasterio(subds, name_root=varname))
return list(chain.from_iterable(sources))

names = make_raster_names(name_root, rast.count)
return [
RasterioRasterSource(rast, i + 1, name=names[i])
Expand All @@ -88,6 +91,9 @@ def prep_raster(rast, name_root=None) -> list:
except ImportError:
pass


def prep_raster_xarray(rast, name_root=None) -> list:

try:
import rioxarray # noqa: F401
import xarray
Expand All @@ -102,6 +108,31 @@ def prep_raster(rast, name_root=None) -> list:
except ImportError:
pass


def prep_raster(rast, name_root=None) -> list:
    """Normalize a raster input into a flat list of raster sources.

    Args:
        rast: ``None``, a ``RasterSource``, a list/tuple of inputs, or any
            object that one of the backend loaders can handle.
        name_root: Name root forwarded to the backend loaders.

    Returns:
        ``[None]`` when ``rast`` is ``None``; ``[rast]`` when it already is a
        ``RasterSource``; otherwise the flattened sources produced for the
        input (or for every element of a list/tuple input).

    Raises:
        Exception: If no backend loader returns a non-empty result.
    """
    # TODO add some hooks to allow RasterSource implementations
    # defined outside this library to handle other input types.
    if rast is None:
        return [None]

    if isinstance(rast, RasterSource):
        return [rast]

    if type(rast) in (list, tuple):
        # When every element is a path, each file's base name (without
        # extension) becomes the name root of the sources it yields.
        paths_only = all(isinstance(item, (str, os.PathLike)) for item in rast)

        flattened = []
        for item in rast:
            if paths_only:
                stem = os.path.splitext(os.path.basename(item))[0]
                flattened.extend(prep_raster(item, name_root=stem))
            else:
                flattened.extend(prep_raster(item))
        return flattened

    # Try each backend in turn; a loader that cannot handle the input
    # yields a falsy result and the next one is consulted.
    backends = (prep_raster_gdal, prep_raster_rasterio, prep_raster_xarray)
    for backend in backends:
        found = backend(rast, name_root)
        if found:
            return found

    raise Exception("Unhandled raster datatype")


Expand Down
77 changes: 77 additions & 0 deletions python/tests/test_exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,3 +1581,80 @@ def test_crs_match_after_normalization(tmp_path):
with warnings.catch_warnings():
warnings.simplefilter("error")
exact_extract(rast, square, "mean")


@pytest.fixture()
def multidim_nc(tmp_path):
    """Write a small netCDF file with two variables and return its path.

    The file is created through GDAL's multidimensional API and holds three
    dimensions (``longitude`` = 3, ``latitude`` = 4, ``time`` = 2), one
    coordinate array per dimension, and two Float32 variables, ``t2m`` and
    ``tp``, each defined over (time, latitude, longitude).

    The test using this fixture is skipped when ``osgeo.gdal`` cannot be
    imported.
    """
    gdal = pytest.importorskip("osgeo.gdal")

    fname = str(tmp_path / "test_multidim.nc")

    nx = 3
    ny = 4
    nt = 2

    drv = gdal.GetDriverByName("netCDF")

    ds = drv.CreateMultiDimensional(fname)
    rg = ds.GetRootGroup()

    x = rg.CreateDimension("longitude", None, None, nx)
    y = rg.CreateDimension("latitude", None, None, ny)
    t = rg.CreateDimension("time", None, None, nt)

    # Create one coordinate array per dimension, named after the dimension.
    for dim in (x, y, t):
        values = np.arange(dim.GetSize()).astype(np.float32)

        # Spatial coordinates are offset by 0.5 (presumably cell centers);
        # the time coordinate keeps its integer values.
        if dim.GetName() != t.GetName():
            values += 0.5
        # Latitude values are stored in descending order.
        if dim.GetName() == y.GetName():
            values = np.flip(values)

        arr = rg.CreateMDArray(
            dim.GetName(), [dim], gdal.ExtendedDataType.Create(gdal.GDT_Float32)
        )
        arr.WriteArray(values)
        standard_name = arr.CreateAttribute(
            "standard_name", [], gdal.ExtendedDataType.CreateString()
        )
        standard_name.Write(dim.GetName())

    # First variable: consecutive integers laid out as (time, lat, lon).
    t2m = rg.CreateMDArray(
        "t2m", [t, y, x], gdal.ExtendedDataType.Create(gdal.GDT_Float32)
    )
    t2m_data = np.arange(nx * ny * nt).reshape((nt, ny, nx))
    t2m.WriteArray(t2m_data)

    # Second variable: element-wise square root of the first.
    tp_data = np.sqrt(t2m_data)
    tp = rg.CreateMDArray(
        "tp", [t, y, x], gdal.ExtendedDataType.Create(gdal.GDT_Float32)
    )
    tp.WriteArray(tp_data)

    return fname


@pytest.mark.parametrize("libname", ("gdal", "rasterio"))
def test_gdal_multi_variable(multidim_nc, libname):
    """Check that every variable/band of a multi-variable file is summarized.

    The file from ``multidim_nc`` is opened with the library named by
    ``libname`` and ``count`` and ``sum`` are computed over one rectangle.
    """
    polygon = make_rect(0.5, 0.5, 2.5, 2.5)
    dataset = open_with_lib(multidim_nc, libname)

    features = exact_extract(dataset, polygon, ["count", "sum"])

    # One entry per (variable, band, operation) combination.
    expected = {
        "t2m_band_1_count": 4.0,
        "t2m_band_1_sum": 28.0,
        "t2m_band_2_count": 4.0,
        "t2m_band_2_sum": 76.0,
        "tp_band_1_count": 4.0,
        "tp_band_1_sum": 10.437034457921982,
        "tp_band_2_count": 4.0,
        "tp_band_2_sum": 17.4051194190979,
    }

    assert features[0]["properties"] == pytest.approx(expected)

0 comments on commit d60c3f2

Please sign in to comment.