Skip to content

Commit

Permalink
Add splitting utilities for GeoDatasets (#866)
Browse files Browse the repository at this point in the history
* add extent_crop to BoundingBox

* add extent_crop param to RasterDataset

* train_test_split function

* minor changes

* fix circular import

* remove extent_crop

* move existing functions to new file

* refactor random_nongeo_split

* refactor random_bbox_splitting

* add roi_split

* add random_bbox_assignment

* add input checks

* fix input type

* minor reorder

* add tests

* add non-overlapping test

* more tests

* fix tests

* additional tests

* check overlapping rois

* add time_series_split with tests

* fix random_nongeo_split to work with fractions in torch 1.9

* modify random_nongeo_split test for coverage

* add random_grid_cell_assignment with tests

* add test

* insert object into new indexes

* check grid_size

* better tests

* small type fix

* fix again

* rm .DS_Store

* fix typo

Co-authored-by: Adam J. Stewart <[email protected]>

* bump version added

* add to __init__

* add to datasets.rst

* use accumulate from itertools

* clarify grid_size

* remove random_nongeo_split

* remove _create_geodataset_like

* black reformatting

* Update tests/datasets/test_splits.py

Co-authored-by: Adam J. Stewart <[email protected]>

* change import

* docstrings

* undo intersection change

* use microsecond

* use isclose

* black

* fix typing

* add comments

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
pmandiola and adamjstewart authored Feb 21, 2023
1 parent 4a92cf4 commit ceeec81
Show file tree
Hide file tree
Showing 6 changed files with 784 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,12 @@ Collation Functions
.. autofunction:: concat_samples
.. autofunction:: merge_samples
.. autofunction:: unbind_samples

Splitting Functions
^^^^^^^^^^^^^^^^^^^

.. autofunction:: random_bbox_assignment
.. autofunction:: random_bbox_splitting
.. autofunction:: random_grid_cell_assignment
.. autofunction:: roi_split
.. autofunction:: time_series_split
322 changes: 322 additions & 0 deletions tests/datasets/test_splits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from math import floor, isclose
from typing import Any, Dict, List, Sequence, Tuple, Union

import pytest
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
GeoDataset,
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
roi_split,
time_series_split,
)


class CustomGeoDataset(GeoDataset):
    """Minimal in-memory GeoDataset used to exercise the splitting utilities.

    Every item is a ``(BoundingBox, content)`` pair that is inserted into the
    dataset's index; ``content`` is the string payload handed back by
    ``__getitem__``.
    """

    def __init__(
        self,
        # An immutable tuple default avoids sharing one mutable list between
        # every instance created without an explicit ``items`` argument.
        items: Sequence[Tuple[BoundingBox, str]] = (
            (BoundingBox(0, 1, 0, 1, 0, 40), ""),
        ),
        crs: CRS = CRS.from_epsg(3005),
        res: float = 1,
    ) -> None:
        """Populate the index with the given items.

        Args:
            items: pairs of bounding box and string content to index
            crs: coordinate reference system stored on the dataset as ``_crs``
            res: resolution stored on the dataset as ``res``
        """
        super().__init__()
        for box, content in items:
            # All entries share id 0; they are told apart only by their bounds
            # and by the object attached to them.
            self.index.insert(0, tuple(box), content)
        self._crs = crs
        self.res = res

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Return the content of the first indexed item intersecting ``query``.

        Args:
            query: bounding box used to look up the index

        Returns:
            a dict with a single ``"content"`` key holding the stored string

        Raises:
            IndexError: if no indexed item intersects ``query``
        """
        hits = self.index.intersection(tuple(query), objects=True)
        # A bare ``next()`` would leak StopIteration out of ``__getitem__``
        # when nothing matches; surface a lookup error instead.
        hit = next(iter(hits), None)
        if hit is None:
            raise IndexError(f"query: {query} not found in index")
        return {"content": hit.object}


@pytest.mark.parametrize(
    "lengths,expected_lengths",
    [
        # Absolute number of items per split
        ([2, 1, 1], [2, 1, 1]),
        # Fractions of the dataset (leaving a remainder to distribute)
        ([1 / 3, 1 / 3, 1 / 3], [2, 1, 1]),
    ],
)
def test_random_bbox_assignment(
    lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int]
) -> None:
    """Check the sizes, disjointness and coverage of random_bbox_assignment."""
    # Four unit boxes laid side by side along x, labelled "a" through "d"
    ds = CustomGeoDataset(
        [
            (BoundingBox(minx, minx + 1, 0, 1, 0, 0), label)
            for minx, label in enumerate("abcd")
        ]
    )

    train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths)

    # Each split holds the expected number of items
    for position, split in enumerate((train_ds, val_ds, test_ds)):
        assert len(split) == expected_lengths[position]

    # Splits share no items, or at most zero-area borders
    for first, second in (
        (train_ds, val_ds),
        (val_ds, test_ds),
        (test_ds, train_ds),
    ):
        overlap = first & second
        assert len(overlap) == 0 or isclose(_get_total_area(overlap), 0)

    # Together the splits span the original bounds
    combined = train_ds | val_ds | test_ds
    assert combined.bounds == ds.bounds

    # A split can still be queried like the original dataset
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)


def test_random_bbox_assignment_invalid_inputs() -> None:
    """Check that random_bbox_assignment rejects malformed ``lengths``."""
    # (lengths passed in, error message expected from the ValueError)
    invalid_cases: List[Tuple[Sequence[Union[int, float]], str]] = [
        (
            [2, 2, 1],
            "Sum of input lengths must equal 1 or the length of dataset's index.",
        ),
        (
            [1 / 2, 3 / 4, -1 / 4],
            "All items in input lengths must be greater than 0.",
        ),
    ]
    for bad_lengths, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            random_bbox_assignment(CustomGeoDataset(), lengths=bad_lengths)


def _get_total_area(dataset: GeoDataset) -> float:
    """Return the summed area of every bounding box in the dataset's index."""
    # Querying the index with its own bounds yields every stored entry.
    entries = dataset.index.intersection(dataset.index.bounds, objects=True)
    # Start from 0.0 so an empty index still yields a float.
    return sum((BoundingBox(*entry.bounds).area for entry in entries), 0.0)


def test_random_bbox_splitting() -> None:
    """Check the areas, disjointness and coverage of random_bbox_splitting.

    Also checks that invalid ``fractions`` raise ``ValueError``.
    """
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 1, 0, 1, 0, 0), "a"),
            (BoundingBox(1, 2, 0, 1, 0, 0), "b"),
            (BoundingBox(2, 3, 0, 1, 0, 0), "c"),
            (BoundingBox(3, 4, 0, 1, 0, 0), "d"),
        ]
    )

    ds_area = _get_total_area(ds)

    train_ds, val_ds, test_ds = random_bbox_splitting(
        ds, fractions=[1 / 2, 1 / 4, 1 / 4]
    )
    train_ds_area = _get_total_area(train_ds)
    val_ds_area = _get_total_area(val_ds)
    test_ds_area = _get_total_area(test_ds)

    # Check datasets areas. Areas are floating-point sums, so compare with a
    # tolerance (as the other area checks in this file do) instead of ``==``.
    assert isclose(train_ds_area, ds_area / 2)
    assert isclose(val_ds_area, ds_area / 4)
    assert isclose(test_ds_area, ds_area / 4)

    # No overlap
    assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
    assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
    assert len(test_ds & train_ds) == 0 or isclose(
        _get_total_area(test_ds & train_ds), 0
    )

    # Union equals original
    assert (train_ds | val_ds | test_ds).bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area)

    # Test __getitem__
    x = train_ds[train_ds.bounds]
    assert isinstance(x, dict)
    assert isinstance(x["content"], str)

    # Test invalid input fractions
    with pytest.raises(ValueError, match="Sum of input fractions must equal 1."):
        random_bbox_splitting(ds, fractions=[1 / 2, 1 / 3, 1 / 4])
    with pytest.raises(
        ValueError, match="All items in input fractions must be greater than 0."
    ):
        random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4])


def test_random_grid_cell_assignment() -> None:
    """Check sizes, disjointness and coverage of random_grid_cell_assignment.

    Also checks that invalid ``fractions`` or ``grid_size`` raise ``ValueError``.
    """
    # Two adjacent 12 x 12 boxes
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 12, 0, 12, 0, 0), "a"),
            (BoundingBox(12, 24, 0, 12, 0, 0), "b"),
        ]
    )

    train_ds, val_ds, test_ds = random_grid_cell_assignment(
        ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=5
    )

    # Two boxes, each cut into a 5 x 5 grid
    n_cells = 2 * 5**2

    # The first split is expected to hold one cell beyond its exact half
    assert len(train_ds) == 1 / 2 * n_cells + 1
    for quarter_split in (val_ds, test_ds):
        assert len(quarter_split) == floor(1 / 4 * n_cells)

    # Splits share no cells, or at most zero-area borders
    for first, second in (
        (train_ds, val_ds),
        (val_ds, test_ds),
        (test_ds, train_ds),
    ):
        overlap = first & second
        assert len(overlap) == 0 or isclose(_get_total_area(overlap), 0)

    # Together the splits reproduce the original extent and area
    combined = train_ds | val_ds | test_ds
    assert combined.bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))

    # A split can still be queried like the original dataset
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)

    # Invalid arguments: (keyword arguments, expected error message)
    invalid_calls: List[Tuple[Dict[str, Any], str]] = [
        (
            {"fractions": [1 / 2, 1 / 3, 1 / 4]},
            "Sum of input fractions must equal 1.",
        ),
        (
            {"fractions": [1 / 2, 3 / 4, -1 / 4]},
            "All items in input fractions must be greater than 0.",
        ),
        (
            {"fractions": [1 / 2, 1 / 4, 1 / 4], "grid_size": 1},
            "Input grid_size must be greater than 1.",
        ),
    ]
    for kwargs, message in invalid_calls:
        with pytest.raises(ValueError, match=message):
            random_grid_cell_assignment(ds, **kwargs)


def test_roi_split() -> None:
    """Check sizes, disjointness and coverage of roi_split.

    Also checks that overlapping ROIs raise ``ValueError``.
    """
    # Four unit boxes laid side by side along x, labelled "a" through "d"
    ds = CustomGeoDataset(
        [
            (BoundingBox(minx, minx + 1, 0, 1, 0, 0), label)
            for minx, label in enumerate("abcd")
        ]
    )

    # The second and third ROI meet at x = 3.5, inside the last box
    regions = [
        BoundingBox(0, 2, 0, 1, 0, 0),
        BoundingBox(2, 3.5, 0, 1, 0, 0),
        BoundingBox(3.5, 4, 0, 1, 0, 0),
    ]
    train_ds, val_ds, test_ds = roi_split(ds, rois=regions)

    # Each split holds the expected number of items
    for split, expected_size in ((train_ds, 2), (val_ds, 2), (test_ds, 1)):
        assert len(split) == expected_size

    # Splits share no items, or at most zero-area borders
    for first, second in (
        (train_ds, val_ds),
        (val_ds, test_ds),
        (test_ds, train_ds),
    ):
        overlap = first & second
        assert len(overlap) == 0 or isclose(_get_total_area(overlap), 0)

    # Together the splits reproduce the original extent and area
    combined = train_ds | val_ds | test_ds
    assert combined.bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))

    # A split can still be queried like the original dataset
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)

    # Overlapping ROIs must be rejected
    overlapping_regions = [
        BoundingBox(0, 2, 0, 1, 0, 0),
        BoundingBox(1, 3, 0, 1, 0, 0),
    ]
    with pytest.raises(ValueError, match="ROIs in input rois can't overlap."):
        roi_split(ds, rois=overlapping_regions)


@pytest.mark.parametrize(
    "lengths,expected_lengths",
    [
        # Explicit (start, end) timestamp pairs
        ([(0, 20), (20, 35), (35, 40)], [2, 2, 1]),
        # Absolute durations
        ([20, 15, 5], [2, 2, 1]),
        # Fractions of the time range (leaving a remainder to distribute)
        ([1 / 2, 3 / 8, 1 / 8], [2, 2, 1]),
    ],
)
def test_time_series_split(
    lengths: Sequence[Union[Tuple[int, int], int, float]],
    expected_lengths: Sequence[int],
) -> None:
    """Check the sizes, disjointness and coverage of time_series_split."""
    # One spatial box observed over four consecutive 10-unit time windows
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 1, 0, 1, mint, mint + 10), label)
            for mint, label in zip(range(0, 40, 10), "abcd")
        ]
    )

    train_ds, val_ds, test_ds = time_series_split(ds, lengths)

    # Each split holds the expected number of items
    for position, split in enumerate((train_ds, val_ds, test_ds)):
        assert len(split) == expected_lengths[position]

    # Splits share no items at all
    for first, second in (
        (train_ds, val_ds),
        (val_ds, test_ds),
        (test_ds, train_ds),
    ):
        assert len(first & second) == 0

    # Together the splits span the original bounds
    combined = train_ds | val_ds | test_ds
    assert combined.bounds == ds.bounds

    # A split can still be queried like the original dataset
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)


def test_time_series_split_invalid_input() -> None:
    """Check that time_series_split rejects malformed ``lengths``."""
    # (lengths passed in, error message expected from the ValueError)
    invalid_cases: List[Tuple[Any, str]] = [
        (
            [(0, 20), (35, 20), (35, 40)],
            "Pairs of timestamps in lengths must have end greater than start.",
        ),
        (
            [(0, 20), (20, 35)],
            "Pairs of timestamps in lengths must cover dataset's time bounds.",
        ),
        (
            [(0, 20), (20, 45)],
            "Pairs of timestamps in lengths can't be out of dataset's time bounds.",
        ),
        (
            [(0, 10), (10, 20), (15, 40)],
            "Pairs of timestamps in lengths can't overlap.",
        ),
        (
            [1 / 2, 1 / 2, 1 / 2],
            "Sum of input lengths must equal 1 or the dataset's time length.",
        ),
        (
            [20, 25, -5],
            "All items in input lengths must be greater than 0.",
        ),
    ]
    for bad_lengths, message in invalid_cases:
        # A fresh default dataset per case, as each call is independent
        with pytest.raises(ValueError, match=message):
            time_series_split(CustomGeoDataset(), lengths=bad_lengths)
29 changes: 29 additions & 0 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,35 @@ def test_intersects(
bbox2 = BoundingBox(*test_input)
assert bbox1.intersects(bbox2) == bbox2.intersects(bbox1) == expected

@pytest.mark.parametrize(
    "proportion,horizontal,expected",
    [
        # horizontal=True: the x range is cut at the given proportion
        (0.25, True, ((0, 0.25, 0, 1, 0, 1), (0.25, 1, 0, 1, 0, 1))),
        # horizontal=False: the y range is cut at the given proportion
        (0.25, False, ((0, 1, 0, 0.25, 0, 1), (0, 1, 0.25, 1, 0, 1))),
    ],
)
def test_split(
    self,
    proportion: float,
    horizontal: bool,
    expected: Tuple[
        Tuple[float, float, float, float, float, float],
        Tuple[float, float, float, float, float, float],
    ],
) -> None:
    """Check that BoundingBox.split yields the two expected halves."""
    whole = BoundingBox(0, 1, 0, 1, 0, 1)
    expected_first, expected_second = expected

    first, second = whole.split(proportion, horizontal)

    assert first == BoundingBox(*expected_first)
    assert second == BoundingBox(*expected_second)
    # The union of the two parts gives back the box that was split
    assert first | second == whole

def test_split_error(self) -> None:
    """Check that BoundingBox.split rejects a proportion of 1.5."""
    unit_box = BoundingBox(0, 1, 0, 1, 0, 1)
    expected_message = "Input proportion must be between 0 and 1."
    with pytest.raises(ValueError, match=expected_message):
        unit_box.split(1.5)

def test_picklable(self) -> None:
bbox = BoundingBox(0, 1, 2, 3, 4, 5)
x = pickle.dumps(bbox)
Expand Down
13 changes: 13 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@
SpaceNet6,
SpaceNet7,
)
from .splits import (
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
roi_split,
time_series_split,
)
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
Expand Down Expand Up @@ -207,4 +214,10 @@
"merge_samples",
"stack_samples",
"unbind_samples",
# Splits
"random_bbox_assignment",
"random_bbox_splitting",
"random_grid_cell_assignment",
"roi_split",
"time_series_split",
)
Loading

0 comments on commit ceeec81

Please sign in to comment.