Skip to content

Commit

Permalink
Python: Warn on CRS mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Aug 28, 2024
1 parent 0200bf6 commit 8bec482
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 5 deletions.
56 changes: 55 additions & 1 deletion python/src/exactextract/exact_extract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import functools
import os
import warnings
from itertools import chain
from typing import Mapping, Optional

Expand Down Expand Up @@ -253,6 +255,57 @@ def prep_writer(output, srs_wkt, options):
raise Exception("Unsupported value of output")


@functools.cache
def crs_matches(a, b):
if a.srs_wkt() is None or b.srs_wkt() is None:
return True

if a.srs_wkt() == b.srs_wkt():
return True

try:
from osgeo import osr

srs_a = osr.SpatialReference()
srs_b = osr.SpatialReference()

srs_a.ImportFromWkt(a.srs_wkt())
srs_b.ImportFromWkt(b.srs_wkt())

srs_a.StripVertical()
srs_b.StripVertical()

return srs_a.IsSame(srs_b)

except ImportError:
return False


def warn_on_crs_mismatch(vec, ops):

check_rast_vec = True
check_rast_weights = True

for op in ops:
if check_rast_vec and not crs_matches(vec, op.values):
check_rast_vec = False
warnings.warn(
"Spatial reference system of input features does not exactly match raster.",
RuntimeWarning,
)

if (
check_rast_weights
and op.weights is not None
and not crs_matches(vec, op.weights)
):
check_rast_weights = False
warnings.warn(
"Spatial reference system of input features does not exactly match weighting raster.",
RuntimeWarning,
)


def exact_extract(
rast,
vec,
Expand Down Expand Up @@ -321,7 +374,6 @@ def exact_extract(
rast = prep_raster(rast)
weights = prep_raster(weights, name_root="weight")
vec = prep_vec(vec)
# TODO: check CRS and transform if necessary/possible?

if output_options is None:
output_options = {}
Expand All @@ -330,6 +382,8 @@ def exact_extract(
ops, rast, weights, add_unique=output_options.get("frac_as_map", False)
)

warn_on_crs_mismatch(vec, ops)

if "frac_as_map" in output_options:
output_options = copy.copy(output_options)
del output_options["frac_as_map"]
Expand Down
108 changes: 104 additions & 4 deletions python/tests/test_exact_extract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import contextlib
import math
import os
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -45,6 +47,17 @@ def make_rect(xmin, ymin, xmax, ymax, id=None, properties=None):
return f


@contextlib.contextmanager
def use_gdal_exceptions():
from osgeo import gdal

prev = gdal.GetUseExceptions()
gdal.UseExceptions()
yield
if not prev:
gdal.UseExceptions()


@pytest.mark.parametrize("output_format", ("geojson", "pandas"), indirect=True)
@pytest.mark.parametrize(
"stat,expected",
Expand Down Expand Up @@ -653,10 +666,19 @@ def test_default_weight():


def create_gdal_raster(
fname, values, *, gt=None, gdal_type=None, nodata=None, scale=None, offset=None
fname,
values,
*,
gt=None,
gdal_type=None,
nodata=None,
scale=None,
offset=None,
crs=None,
):
gdal = pytest.importorskip("osgeo.gdal")
gdal_array = pytest.importorskip("osgeo.gdal_array")
osr = pytest.importorskip("osgeo.osr")

drv = gdal.GetDriverByName("GTiff")

Expand All @@ -674,6 +696,14 @@ def create_gdal_raster(
else:
ds.SetGeoTransform(gt)

if crs is not None:
srs = osr.SpatialReference()
if crs.startswith("EPSG"):
srs.ImportFromEPSG(int(crs.strip("EPSG:")))
else:
srs.ImportFromWkt(crs)
ds.SetSpatialRef(srs)

if nodata is not None:
if type(nodata) in {list, tuple}:
for i, v in enumerate(nodata):
Expand Down Expand Up @@ -706,7 +736,7 @@ def create_gdal_raster(
ds.GetRasterBand(1).GetMaskBand().WriteArray(~values.mask)


def create_gdal_features(fname, features, name="test"):
def create_gdal_features(fname, features, name="test", *, crs=None):
gdal = pytest.importorskip("osgeo.gdal")

import json
Expand All @@ -719,8 +749,9 @@ def create_gdal_features(fname, features, name="test"):
)
tf.flush()

ds = gdal.VectorTranslate(str(fname), tf.name)
ds = None # noqa: F841
with use_gdal_exceptions():
ds = gdal.VectorTranslate(str(fname), tf.name, dstSRS=crs)
ds = None # noqa: F841

os.remove(tf.name)

Expand Down Expand Up @@ -1481,3 +1512,72 @@ def test_grid_compat_tol():
exact_extract(
values, square, "weighted_mean", weights=weights, grid_compat_tol=1e-2
)


def test_crs_mismatch(tmp_path):
crs_list = (4269, 4326, None)

rasters = {}
features = {}
for crs in crs_list:
rasters[crs] = tmp_path / f"{crs}.tif"
features[crs] = tmp_path / f"{crs}.shp"
create_gdal_raster(
rasters[crs], np.arange(9).reshape(3, 3), crs=f"EPSG:{crs}" if crs else None
)
create_gdal_features(
features[crs],
[make_rect(0.5, 0.5, 2.5, 2.5)],
crs=f"EPSG:{crs}" if crs else None,
)

with pytest.warns(
RuntimeWarning, match="input features does not exactly match raster"
) as record:
exact_extract(rasters[4326], features[4269], "mean")
assert len(record) == 1

with pytest.warns(
RuntimeWarning, match="input features does not exactly match raster"
) as record:
exact_extract([rasters[4326], rasters[4269]], features[4326], "mean")
assert len(record) == 1

with pytest.warns(
RuntimeWarning, match="input features does not exactly match weighting raster"
) as record:
exact_extract(
rasters[4326], features[4326], "weighted_mean", weights=rasters[4269]
)
assert len(record) == 1

# make sure only a single warning is raised
with pytest.warns(
RuntimeWarning, match="input features does not exactly match raster"
) as record:
exact_extract([rasters[4326], rasters[4269]], features[4326], ["mean", "sum"])
assert len(record) == 1

# any CRS is considered to match an undefined CRS
with warnings.catch_warnings():
warnings.simplefilter("error")
exact_extract(rasters[None], features[4326], "mean")


def test_crs_match_after_normalization(tmp_path):

pytest.importorskip("osgeo.osr")

rast = tmp_path / "test.tif"
square = tmp_path / "test.shp"

rast_crs = 'PROJCS["WGS 84 / UTM zone 4N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",-159],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32604"]]'

vec_crs = 'PROJCRS["WGS 84 / UTM zone 4N",BASEGEOGCRS["WGS 84",DATUM["World Geodetic System 1984",ELLIPSOID["WGS 84",6378137,298.257223563,LENGTHUNIT["metre",1]]],PRIMEM["Greenwich",0,ANGLEUNIT["degree",0.0174532925199433]],ID["EPSG",4326]],CONVERSION["UTM zone 4N",METHOD["Transverse Mercator",ID["EPSG",9807]],PARAMETER["Latitude of natural origin",0,ANGLEUNIT["degree",0.0174532925199433],ID["EPSG",8801]],PARAMETER["Longitude of natural origin",-159,ANGLEUNIT["degree",0.0174532925199433],ID["EPSG",8802]],PARAMETER["Scale factor at natural origin",0.9996,SCALEUNIT["unity",1],ID["EPSG",8805]],PARAMETER["False easting",500000,LENGTHUNIT["metre",1],ID["EPSG",8806]],PARAMETER["False northing",0,LENGTHUNIT["metre",1],ID["EPSG",8807]]],CS[Cartesian,2],AXIS["easting",east,ORDER[1],LENGTHUNIT["metre",1]],AXIS["northing",north,ORDER[2],LENGTHUNIT["metre",1]],ID["EPSG",32604]]'

create_gdal_raster(rast, np.arange(9).reshape(3, 3), crs=rast_crs)
create_gdal_features(square, [make_rect(0.5, 0.5, 2.5, 2.5)], crs=vec_crs)

with warnings.catch_warnings():
warnings.simplefilter("error")
exact_extract(rast, square, "mean")

0 comments on commit 8bec482

Please sign in to comment.