From 0657ecd87038858c18df827c177872ebc64851e0 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Mon, 4 Nov 2024 17:02:06 -0500 Subject: [PATCH] use rasterio bounds() and from_bounds() to transform Box's remove usage of Box.normalize() since that should now be unnecessary --- rastervision_core/rastervision/core/box.py | 2 +- .../data/crs_transformer/crs_transformer.py | 104 +++++++++++++----- .../identity_crs_transformer.py | 4 +- .../rasterio_crs_transformer.py | 80 +++++++++++--- .../data/raster_source/multi_raster_source.py | 2 +- .../core/data/raster_source/raster_source.py | 2 +- .../core/data/raster_source/xarray_source.py | 2 +- .../rastervision/core/data/utils/rasterio.py | 2 +- .../test_rasterio_crs_transformer.py | 4 +- tests/core/data/mock_crs_transformer.py | 4 +- .../raster_source/test_multi_raster_source.py | 8 +- tests/core/data/utils/test_rasterio.py | 6 +- 12 files changed, 159 insertions(+), 61 deletions(-) diff --git a/rastervision_core/rastervision/core/box.py b/rastervision_core/rastervision/core/box.py index 3521a5716..47972e99d 100644 --- a/rastervision_core/rastervision/core/box.py +++ b/rastervision_core/rastervision/core/box.py @@ -271,7 +271,7 @@ def to_shapely(self) -> Polygon: def to_rasterio(self) -> RioWindow: """Convert to a Rasterio Window.""" - return RioWindow.from_slices(*self.normalize().to_slices()) + return RioWindow.from_slices(*self.to_slices()) def to_slices(self, h_step: int | None = None, w_step: int | None = None) -> tuple[slice, slice]: diff --git a/rastervision_core/rastervision/core/data/crs_transformer/crs_transformer.py b/rastervision_core/rastervision/core/data/crs_transformer/crs_transformer.py index a64163e5b..5bc9756e4 100644 --- a/rastervision_core/rastervision/core/data/crs_transformer/crs_transformer.py +++ b/rastervision_core/rastervision/core/data/crs_transformer/crs_transformer.py @@ -59,24 +59,18 @@ def map_to_pixel(self, inp, bbox: Box | None = None): Coordinate-transformed input in the same format. """ if isinstance(inp, Box): - box_in = inp - ymin, xmin, ymax, xmax = box_in - xmin_tf, ymin_tf = self._map_to_pixel((xmin, ymin)) - xmax_tf, ymax_tf = self._map_to_pixel((xmax, ymax)) - box_out = Box(ymin_tf, xmin_tf, ymax_tf, xmax_tf) + box_out = self._map_to_pixel_box(inp) if bbox is not None: box_out = box_out.to_local_coords(bbox) return box_out elif isinstance(inp, BaseGeometry): - geom_in = inp - geom_out = transform( - lambda x, y, z=None: self._map_to_pixel((x, y)), geom_in) + geom_out = self._map_to_pixel_geom(inp) if bbox is not None: xmin, ymin = bbox.xmin, bbox.ymin geom_out = translate(geom_out, xoff=-xmin, yoff=-ymin) return geom_out elif len(inp) == 2: - out = self._map_to_pixel(inp) + out = self._map_to_pixel_point(inp) out_x, out_y = out out = (np.array(out_x), np.array(out_y)) if bbox is not None: @@ -128,25 +122,21 @@ def pixel_to_map(self, inp, bbox: Box | None = None): box_in = inp if bbox is not None: box_in = box_in.to_global_coords(bbox) - ymin, xmin, ymax, xmax = box_in - xmin_tf, ymin_tf = self._pixel_to_map((xmin, ymin)) - xmax_tf, ymax_tf = self._pixel_to_map((xmax, ymax)) - box_out = Box(ymin_tf, xmin_tf, ymax_tf, xmax_tf) + box_out = self._pixel_to_map_box(box_in) return box_out elif isinstance(inp, BaseGeometry): geom_in = inp if bbox is not None: xmin, ymin = bbox.xmin, bbox.ymin geom_in = translate(geom_in, xoff=xmin, yoff=ymin) - geom_out = transform( - lambda x, y, z=None: self._pixel_to_map((x, y)), geom_in) + geom_out = self._pixel_to_map_geom(geom_in) return geom_out elif len(inp) == 2: if bbox is not None: xmin, ymin = bbox.xmin, bbox.ymin inp_x, inp_y = inp inp = (inp_x + xmin, inp_y + ymin) - out = self._pixel_to_map(inp) + out = self._pixel_to_map_point(inp) out_x, out_y = out out = (np.array(out_x), np.array(out_y)) return out @@ -155,25 +145,87 @@ def pixel_to_map(self, inp, bbox: Box | None = None): 'Input must be 2-tuple or Box or shapely geometry.') @abstractmethod - def _map_to_pixel(self, point: tuple[float, float]) -> tuple[int, int]: - """Transform point from map to pixel coordinates. + def _map_to_pixel_point(self, + point: tuple[float, float]) -> tuple[int, int]: + """Transform point(s) from map to pixel coordinates. + + Args: + map_point: ``(x, y)`` tuple in map coordinates (eg. lon/lat). ``x`` + and ``y`` can be single values or array-like. + + Returns: + ``(x, y)`` tuple in pixel coordinates. + """ + + def _map_to_pixel_box(self, box: Box) -> Box: + """Transform a :class:`Box` from map to pixel coordinates. + + Args: + box: Box to transform. + + Returns: + Box in pixel coordinates. + """ + ymin, xmin, ymax, xmax = box + xmin_tf, ymin_tf = self._map_to_pixel_point((xmin, ymin)) + xmax_tf, ymax_tf = self._map_to_pixel_point((xmax, ymax)) + pixel_box = Box(ymin_tf, xmin_tf, ymax_tf, xmax_tf) + return pixel_box + + def _map_to_pixel_geom(self, geom: Box) -> Box: + """Transform a shapely geom from map to pixel coordinates. Args: - map_point: (x, y) tuple in map coordinates (eg. lon/lat). x and y - can be single values or array-like. + geom: Geom to transform. Returns: - tuple[int, int]: (x, y) tuple in pixel coordinates. + Geom in pixel coordinates. """ + pixel_geom = transform( + lambda x, y, z=None: self._map_to_pixel_point((x, y)), + geom, + ) + return pixel_geom @abstractmethod - def _pixel_to_map(self, point: tuple[int, int]) -> tuple[float, float]: - """Transform point from pixel to map coordinates. + def _pixel_to_map_point(self, + point: tuple[int, int]) -> tuple[float, float]: + """Transform point(s) from pixel to map coordinates. + + Args: + pixel_point: ``(x, y)`` tuple in pixel coordinates. ``x`` and ``y`` + can be single values or array-like. + + Returns: + ``(x, y)`` tuple in map coordinates (eg. lon/lat). + """ + + def _pixel_to_map_box(self, box: Box) -> Box: + """Transform a :class:`Box` from pixel to map coordinates. + + Args: + box: Box to transform. + + Returns: + Box in map coordinates (eg. lon/lat). + """ + ymin, xmin, ymax, xmax = box + xmin_tf, ymin_tf = self._pixel_to_map_point((xmin, ymin)) + xmax_tf, ymax_tf = self._pixel_to_map_point((xmax, ymax)) + map_box = Box(ymin_tf, xmin_tf, ymax_tf, xmax_tf) + return map_box + + def _pixel_to_map_geom(self, geom: Box) -> Box: + """Transform a shapely geom from pixel to map coordinates. Args: - pixel_point: (x, y) tuple in pixel coordinates. x and y can be - single values or array-like. + geom: Geom to transform. Returns: - tuple[float, float]: (x, y) tuple in map coordinates (eg. lon/lat). + Geom in map coordinates. """ + map_geom = transform( + lambda x, y, z=None: self._pixel_to_map_point((x, y)), + geom, + ) + return map_geom diff --git a/rastervision_core/rastervision/core/data/crs_transformer/identity_crs_transformer.py b/rastervision_core/rastervision/core/data/crs_transformer/identity_crs_transformer.py index 9bcc148ee..f90dc6c48 100644 --- a/rastervision_core/rastervision/core/data/crs_transformer/identity_crs_transformer.py +++ b/rastervision_core/rastervision/core/data/crs_transformer/identity_crs_transformer.py @@ -7,7 +7,7 @@ class IdentityCRSTransformer(CRSTransformer): This is useful for non-georeferenced imagery. """ - def _map_to_pixel(self, map_point): + def _map_to_pixel_point(self, map_point): """Identity function. Args: @@ -18,7 +18,7 @@ def _map_to_pixel(self, map_point): """ return map_point - def _pixel_to_map(self, pixel_point): + def _pixel_to_map_point(self, pixel_point): """Identity function. Args: diff --git a/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py b/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py index ae618840d..cacb2b84c 100644 --- a/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py +++ b/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py @@ -6,8 +6,10 @@ import numpy as np import rasterio as rio from rasterio.transform import (rowcol, xy) +from rasterio.windows import bounds, from_bounds, Window from rasterio import Affine +from rastervision.core.box import Box from rastervision.core.data.crs_transformer import (CRSTransformer, IdentityCRSTransformer) @@ -16,6 +18,9 @@ log = logging.getLogger(__name__) +PIXEL_PRECISION = 6 +MAP_PRECISION = 6 + def pyproj_wrapper( func: Callable[..., tuple[Any, Any]], @@ -54,18 +59,18 @@ class RasterioCRSTransformer(CRSTransformer): def __init__(self, transform: Affine, - image_crs: Any, - map_crs: Any = 'epsg:4326', + image_crs: str, + map_crs: str = 'epsg:4326', round_pixels: bool = True): """Constructor. Args: - transform (Affine): Rasterio affine transform. - image_crs (Any): CRS of image in format that PyProj can handle - eg. wkt or init string. - map_crs (Any): CRS of the labels. Defaults to "epsg:4326". - round_pixels (bool): If True, round outputs of map_to_pixel and - inputs of pixel_to_map to integers. Defaults to False. + transform: Rasterio affine transform. + image_crs: CRS of image in format that PyProj can handle eg. WKT or + init string. + map_crs: CRS of the labels. Defaults to "epsg:4326". + round_pixels: If ``True``, round outputs of :meth:`.map_to_pixel`. + Defaults to ``False``. """ if (image_crs is None) or (image_crs == map_crs): @@ -106,7 +111,7 @@ def __repr__(self) -> str: """ return out - def _map_to_pixel( + def _map_to_pixel_point( self, map_point: tuple[float, float] | tuple[np.ndarray, np.ndarray] ) -> tuple[int, int] | tuple[np.ndarray, np.ndarray]: @@ -120,14 +125,42 @@ def _map_to_pixel( """ image_point = self.map2image(*map_point) x, y = image_point + row, col = rowcol( + self.transform, x, y, op=lambda x: np.round(x, PIXEL_PRECISION)) if self.round_pixels: - row, col = rowcol(self.transform, x, y) - else: - row, col = rowcol(self.transform, x, y, op=lambda x: x) + row, col = np.round(row), np.round(col) pixel_point = (col, row) return pixel_point - def _pixel_to_map( + def _map_to_pixel_box(self, box: Box) -> Box: + ymin, xmin, ymax, xmax = box + xmin, ymin = self.map2image(xmin, ymin) + xmax, ymax = self.map2image(xmax, ymax) + rio_window = from_bounds( + left=xmin, + bottom=ymin, + right=xmax, + top=ymax, + transform=self.transform, + ) + (ymin, ymax), (xmin, xmax) = rio_window.toranges() + ymin, xmin, ymax, xmax = ( + round(ymin, PIXEL_PRECISION), + round(xmin, PIXEL_PRECISION), + round(ymax, PIXEL_PRECISION), + round(xmax, PIXEL_PRECISION), + ) + if self.round_pixels: + ymin, xmin, ymax, xmax = ( + round(ymin), + round(xmin), + round(ymax), + round(xmax), + ) + pixel_box = Box(ymin, xmin, ymax, xmax) + return pixel_box + + def _pixel_to_map_point( self, pixel_point: tuple[int, int] | tuple[np.ndarray, np.ndarray] ) -> tuple[float, float] | tuple[np.ndarray, np.ndarray]: """Transform point from pixel to map-based coordinates. @@ -139,13 +172,24 @@ def _pixel_to_map( (x, y) tuple in map coordinates """ col, row = pixel_point - if self.round_pixels: - col = col.astype(int) if isinstance(col, np.ndarray) else int(col) - row = row.astype(int) if isinstance(row, np.ndarray) else int(row) - image_point = xy(self.transform, row, col, offset='center') - map_point = self.image2map(*image_point) + x, y = xy(self.transform, row, col, offset='center') + map_point = self.image2map(x, y) return map_point + def _pixel_to_map_box(self, box: Box) -> Box: + rio_window = Window(*box.to_xywh()) + xmin, ymin, xmax, ymax = bounds(rio_window, transform=self.transform) + xmin, ymin, xmax, ymax = ( + round(xmin, PIXEL_PRECISION), + round(ymin, PIXEL_PRECISION), + round(xmax, PIXEL_PRECISION), + round(ymax, PIXEL_PRECISION), + ) + xmin, ymin = self.image2map(xmin, ymin) + xmax, ymax = self.image2map(xmax, ymax) + map_box = Box(ymin, xmin, ymax, xmax) + return map_box + @classmethod def from_dataset(cls, dataset: Any, diff --git a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py index 124b1d64d..2e1dca662 100644 --- a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py @@ -132,7 +132,7 @@ def from_stac( crs_transformer = raster_sources[primary_source_idx].crs_transformer if bbox_map_coords is not None: bbox_map_coords = Box(*bbox_map_coords) - bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize() + bbox = crs_transformer.map_to_pixel(bbox_map_coords) elif bbox is not None: bbox = Box(*bbox) diff --git a/rastervision_core/rastervision/core/data/raster_source/raster_source.py b/rastervision_core/rastervision/core/data/raster_source/raster_source.py index 4e7e4be37..1129472db 100644 --- a/rastervision_core/rastervision/core/data/raster_source/raster_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/raster_source.py @@ -166,7 +166,7 @@ def get_chip_by_map_window(self, window_map_coords: 'Box', *args, **kwargs) -> 'np.ndarray': """Same as get_chip(), but input is a window in map coords.""" window_pixel_coords = self.crs_transformer.map_to_pixel( - window_map_coords, bbox=self.bbox).normalize() + window_map_coords, bbox=self.bbox) chip = self.get_chip(window_pixel_coords, *args, **kwargs) return chip diff --git a/rastervision_core/rastervision/core/data/raster_source/xarray_source.py b/rastervision_core/rastervision/core/data/raster_source/xarray_source.py index 437cc350b..88c59e465 100644 --- a/rastervision_core/rastervision/core/data/raster_source/xarray_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/xarray_source.py @@ -147,7 +147,7 @@ def from_stac( bbox = Box(*bbox) elif bbox_map_coords is not None: bbox_map_coords = Box(*bbox_map_coords) - bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize() + bbox = crs_transformer.map_to_pixel(bbox_map_coords) else: bbox = None diff --git a/rastervision_core/rastervision/core/data/utils/rasterio.py b/rastervision_core/rastervision/core/data/utils/rasterio.py index 4fb31f32a..dc53d4010 100644 --- a/rastervision_core/rastervision/core/data/utils/rasterio.py +++ b/rastervision_core/rastervision/core/data/utils/rasterio.py @@ -100,7 +100,7 @@ def write_geotiff_like_geojson(path: str, crs = 'epsg:4326' crs_wkt = pyproj.CRS(crs).to_wkt() geoms = unary_union(list(geojson_to_geoms(geojson))) - bbox = Box.from_shapely(geoms).normalize() + bbox = Box.from_shapely(geoms) write_bbox(path, arr, bbox=bbox, crs_wkt=crs_wkt, **kwargs) diff --git a/tests/core/data/crs_transformer/test_rasterio_crs_transformer.py b/tests/core/data/crs_transformer/test_rasterio_crs_transformer.py index fe7eba234..14de8ff52 100644 --- a/tests/core/data/crs_transformer/test_rasterio_crs_transformer.py +++ b/tests/core/data/crs_transformer/test_rasterio_crs_transformer.py @@ -14,8 +14,8 @@ def setUp(self): self.im_path = data_file_path('3857.tif') self.im_dataset = rasterio.open(self.im_path) self.crs_trans = RasterioCRSTransformer.from_dataset(self.im_dataset) - self.lon_lat = (-115.3063715, 36.1268253) - self.pix_point = (50, 61) + self.lon_lat = (-115.306372, 36.126825) + self.pix_point = (51, 62) def test_map_to_pixel_point(self): # w/o bbox diff --git a/tests/core/data/mock_crs_transformer.py b/tests/core/data/mock_crs_transformer.py index 42c9a9d81..cebaf30f0 100644 --- a/tests/core/data/mock_crs_transformer.py +++ b/tests/core/data/mock_crs_transformer.py @@ -7,8 +7,8 @@ class DoubleCRSTransformer(CRSTransformer): Assumes pixel coords are 2x map coords. """ - def _map_to_pixel(self, map_point): + def _map_to_pixel_point(self, map_point): return (map_point[0] * 2.0, map_point[1] * 2.0) - def _pixel_to_map(self, pixel_point): + def _pixel_to_map_point(self, pixel_point): return (pixel_point[0] / 2.0, pixel_point[1] / 2.0) diff --git a/tests/core/data/raster_source/test_multi_raster_source.py b/tests/core/data/raster_source/test_multi_raster_source.py index b8dea2763..77ad87be1 100644 --- a/tests/core/data/raster_source/test_multi_raster_source.py +++ b/tests/core/data/raster_source/test_multi_raster_source.py @@ -244,10 +244,14 @@ def test_from_stac(self): # test bbox_map_coords bbox_map_coords = Box( - ymin=29.978710, xmin=31.134949, ymax=29.977309, xmax=31.136567) + ymin=29.977309, + xmin=31.134949, + ymax=29.978710, + xmax=31.136567, + ) rs = MultiRasterSource.from_stac( item, assets=['red', 'green'], bbox_map_coords=bbox_map_coords) - self.assertEqual(rs.bbox, Box(ymin=50, xmin=50, ymax=206, xmax=206)) + self.assertEqual(rs.bbox, Box(ymin=51, xmin=50, ymax=207, xmax=206)) # test error if both bbox and bbox_map_coords specified args = dict( diff --git a/tests/core/data/utils/test_rasterio.py b/tests/core/data/utils/test_rasterio.py index 100514b95..f395497c1 100644 --- a/tests/core/data/utils/test_rasterio.py +++ b/tests/core/data/utils/test_rasterio.py @@ -24,16 +24,14 @@ def test_write_bbox(self): geotiff_path = join(tmp_dir, 'test.geotiff') write_bbox(geotiff_path, arr1, bbox=bbox, crs_wkt=crs_wkt) rs = RasterioSource(geotiff_path) - geotiff_bbox = rs.crs_transformer.pixel_to_map( - rs.extent).normalize() + geotiff_bbox = rs.crs_transformer.pixel_to_map(rs.extent) np.testing.assert_array_almost_equal( np.array(list(geotiff_bbox)), np.array(list(bbox)), decimal=3) self.assertEqual(rs.shape, (*arr1.shape, 1)) write_bbox(geotiff_path, arr2, bbox=bbox, crs_wkt=crs_wkt) rs = RasterioSource(geotiff_path) - geotiff_bbox = rs.crs_transformer.pixel_to_map( - rs.extent).normalize() + geotiff_bbox = rs.crs_transformer.pixel_to_map(rs.extent) np.testing.assert_array_almost_equal( np.array(list(geotiff_bbox)), np.array(list(bbox)), decimal=3) self.assertEqual(rs.shape, arr2.shape)