-
Notifications
You must be signed in to change notification settings - Fork 380
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add splitting utilities for GeoDatasets (#866)
* add extent_crop to BoundingBox
* add extent_crop param to RasterDataset
* train_test_split function
* minor changes
* fix circular import
* remove extent_crop
* move existing functions to new file
* refactor random_nongeo_split
* refactor random_bbox_splitting
* add roi_split
* add random_bbox_assignment
* add input checks
* fix input type
* minor reorder
* add tests
* add non-overlapping test
* more tests
* fix tests
* additional tests
* check overlapping rois
* add time_series_split with tests
* fix random_nongeo_split to work with fractions in torch 1.9
* modify random_nongeo_split test for coverage
* add random_grid_cell_assignment with tests
* add test
* insert object into new indexes
* check grid_size
* better tests
* small type fix
* fix again
* rm .DS_Store
* fix typo
  Co-authored-by: Adam J. Stewart <[email protected]>
* bump version added
* add to __init__
* add to datasets.rst
* use accumulate from itertools
* clarify grid_size
* remove random_nongeo_split
* remove _create_geodataset_like
* black reformatting
* Update tests/datasets/test_splits.py
  Co-authored-by: Adam J. Stewart <[email protected]>
* change import
* docstrings
* undo intersection change
* use microsecond
* use isclose
* black
* fix typing
* add comments
---------
Co-authored-by: Adam J. Stewart <[email protected]>
- Loading branch information
1 parent
4a92cf4
commit ceeec81
Showing
6 changed files
with
784 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,322 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
from math import floor, isclose | ||
from typing import Any, Dict, List, Sequence, Tuple, Union | ||
|
||
import pytest | ||
from rasterio.crs import CRS | ||
|
||
from torchgeo.datasets import ( | ||
BoundingBox, | ||
GeoDataset, | ||
random_bbox_assignment, | ||
random_bbox_splitting, | ||
random_grid_cell_assignment, | ||
roi_split, | ||
time_series_split, | ||
) | ||
|
||
|
||
class CustomGeoDataset(GeoDataset):
    """Minimal in-memory GeoDataset used to exercise the splitting utilities.

    Each item is a bounding box paired with a string payload; the payload is
    stored as the object attached to the corresponding index entry.
    """

    def __init__(
        self,
        # An immutable tuple default avoids sharing one mutable list across calls.
        items: Sequence[Tuple[BoundingBox, str]] = (
            (BoundingBox(0, 1, 0, 1, 0, 40), ""),
        ),
        crs: CRS = CRS.from_epsg(3005),
        res: float = 1,
    ) -> None:
        """Populate the index with the given items.

        Args:
            items: (bounding box, content) pairs to insert into the index
            crs: coordinate reference system stored on the dataset
            res: resolution stored on the dataset
        """
        super().__init__()
        for box, content in items:
            # All entries share id 0; nothing in this file looks entries up by id.
            self.index.insert(0, tuple(box), content)
        self._crs = crs
        self.res = res

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Return the content of the first index entry intersecting ``query``.

        Args:
            query: bounding box used to query the index

        Returns:
            a dict with the stored payload under the ``"content"`` key

        Raises:
            IndexError: if no index entry intersects ``query``
        """
        hits = self.index.intersection(tuple(query), objects=True)
        # A default avoids leaking StopIteration to the caller on a miss.
        hit = next(iter(hits), None)
        if hit is None:
            raise IndexError(f"query: {query} not found in index")
        return {"content": hit.object}
|
||
|
||
@pytest.mark.parametrize(
    "lengths,expected_lengths",
    [
        # Absolute number of items per split
        ([2, 1, 1], [2, 1, 1]),
        # Fractions of the dataset (with remainder)
        ([1 / 3, 1 / 3, 1 / 3], [2, 1, 1]),
    ],
)
def test_random_bbox_assignment(
    lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int]
) -> None:
    """Check split sizes, disjointness, coverage and indexing of the splits."""
    # Four unit-wide boxes laid side by side along x, labelled "a".."d"
    items = [
        (BoundingBox(left, left + 1, 0, 1, 0, 0), label)
        for left, label in enumerate("abcd")
    ]
    ds = CustomGeoDataset(items)

    train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths)

    # Each split holds the expected number of items
    sizes = [len(split) for split in (train_ds, val_ds, test_ds)]
    assert sizes == list(expected_lengths)

    # Splits are pairwise disjoint (empty intersection, or one with zero area)
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        overlap = first & second
        assert len(overlap) == 0 or isclose(_get_total_area(overlap), 0)

    # Together the splits span the bounds of the original dataset
    merged = train_ds | val_ds | test_ds
    assert merged.bounds == ds.bounds

    # Indexing a split with its own bounds yields a sample dict
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)
|
||
|
||
def test_random_bbox_assignment_invalid_inputs() -> None:
    """Check that random_bbox_assignment rejects malformed ``lengths``."""
    # Lengths that sum to neither 1 nor the number of items in the index
    bad_sum = "Sum of input lengths must equal 1 or the length of dataset's index."
    with pytest.raises(ValueError, match=bad_sum):
        random_bbox_assignment(CustomGeoDataset(), lengths=[2, 2, 1])

    # A negative entry among the lengths
    non_positive = "All items in input lengths must be greater than 0."
    with pytest.raises(ValueError, match=non_positive):
        random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4])
|
||
|
||
def _get_total_area(dataset: GeoDataset) -> float:
    """Add up the areas of all bounding boxes stored in a dataset's index.

    Args:
        dataset: dataset whose index entries are measured

    Returns:
        the sum of the ``area`` of every entry intersecting the index bounds
    """
    index = dataset.index
    area_sum = 0.0
    # Querying with the index's own bounds visits every entry it returns as a hit
    for entry in index.intersection(index.bounds, objects=True):
        box = BoundingBox(*entry.bounds)
        area_sum = area_sum + box.area
    return area_sum
|
||
|
||
def test_random_bbox_splitting() -> None:
    """Check area shares, disjointness, coverage and input validation.

    The dataset is split with fractions 1/2, 1/4 and 1/4; each split is
    expected to hold the matching share of the total bounding-box area.
    """
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 1, 0, 1, 0, 0), "a"),
            (BoundingBox(1, 2, 0, 1, 0, 0), "b"),
            (BoundingBox(2, 3, 0, 1, 0, 0), "c"),
            (BoundingBox(3, 4, 0, 1, 0, 0), "d"),
        ]
    )

    ds_area = _get_total_area(ds)

    train_ds, val_ds, test_ds = random_bbox_splitting(
        ds, fractions=[1 / 2, 1 / 4, 1 / 4]
    )
    train_ds_area = _get_total_area(train_ds)
    val_ds_area = _get_total_area(val_ds)
    test_ds_area = _get_total_area(test_ds)

    # Check datasets areas. Areas are floats, so compare with a tolerance
    # instead of ==, like the other area checks in this test.
    assert isclose(train_ds_area, ds_area / 2)
    assert isclose(val_ds_area, ds_area / 4)
    assert isclose(test_ds_area, ds_area / 4)

    # No overlap
    assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0)
    assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0)
    assert len(test_ds & train_ds) == 0 or isclose(
        _get_total_area(test_ds & train_ds), 0
    )

    # Union equals original
    assert (train_ds | val_ds | test_ds).bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area)

    # Test __getitem__
    x = train_ds[train_ds.bounds]
    assert isinstance(x, dict)
    assert isinstance(x["content"], str)

    # Test invalid input fractions
    with pytest.raises(ValueError, match="Sum of input fractions must equal 1."):
        random_bbox_splitting(ds, fractions=[1 / 2, 1 / 3, 1 / 4])
    with pytest.raises(
        ValueError, match="All items in input fractions must be greater than 0."
    ):
        random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4])
|
||
|
||
def test_random_grid_cell_assignment() -> None:
    """Check cell counts, disjointness, coverage and input validation."""
    # Two 12x12 boxes side by side along x
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 12, 0, 12, 0, 0), "a"),
            (BoundingBox(12, 24, 0, 12, 0, 0), "b"),
        ]
    )

    train_ds, val_ds, test_ds = random_grid_cell_assignment(
        ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=5
    )

    # Two boxes, each expected to yield a 5x5 grid of cells
    n_cells = 2 * 5**2
    # The train split is expected to hold one cell more than its exact half
    assert len(train_ds) == 1 / 2 * n_cells + 1
    assert len(val_ds) == floor(1 / 4 * n_cells)
    assert len(test_ds) == floor(1 / 4 * n_cells)

    # Splits are pairwise disjoint (empty intersection, or one with zero area)
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        overlap = first & second
        assert len(overlap) == 0 or isclose(_get_total_area(overlap), 0)

    # Together the splits cover the bounds and the area of the original dataset
    merged = train_ds | val_ds | test_ds
    assert merged.bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))

    # Indexing a split with its own bounds yields a sample dict
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)

    # Fractions that do not sum to 1
    bad_sum = "Sum of input fractions must equal 1."
    with pytest.raises(ValueError, match=bad_sum):
        random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 3, 1 / 4])

    # A negative fraction
    non_positive = "All items in input fractions must be greater than 0."
    with pytest.raises(ValueError, match=non_positive):
        random_grid_cell_assignment(ds, fractions=[1 / 2, 3 / 4, -1 / 4])

    # A grid size of 1
    bad_grid = "Input grid_size must be greater than 1."
    with pytest.raises(ValueError, match=bad_grid):
        random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=1)
|
||
|
||
def test_roi_split() -> None:
    """Check split sizes, disjointness, coverage and ROI validation."""
    # Four unit-wide boxes laid side by side along x, labelled "a".."d"
    items = [
        (BoundingBox(left, left + 1, 0, 1, 0, 0), label)
        for left, label in enumerate("abcd")
    ]
    ds = CustomGeoDataset(items)

    # The second and third ROIs meet at x=3.5, inside the box labelled "d"
    rois = [
        BoundingBox(0, 2, 0, 1, 0, 0),
        BoundingBox(2, 3.5, 0, 1, 0, 0),
        BoundingBox(3.5, 4, 0, 1, 0, 0),
    ]
    train_ds, val_ds, test_ds = roi_split(ds, rois=rois)

    # Each split holds the expected number of items
    assert len(train_ds) == 2
    assert len(val_ds) == 2
    assert len(test_ds) == 1

    # Splits are pairwise disjoint (empty intersection, or one with zero area)
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        overlap = first & second
        assert len(overlap) == 0 or isclose(_get_total_area(overlap), 0)

    # Together the splits cover the bounds and the area of the original dataset
    merged = train_ds | val_ds | test_ds
    assert merged.bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))

    # Indexing a split with its own bounds yields a sample dict
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)

    # Overlapping ROIs must be rejected
    overlapping = [BoundingBox(0, 2, 0, 1, 0, 0), BoundingBox(1, 3, 0, 1, 0, 0)]
    with pytest.raises(ValueError, match="ROIs in input rois can't overlap."):
        roi_split(ds, rois=overlapping)
|
||
|
||
@pytest.mark.parametrize(
    "lengths,expected_lengths",
    [
        # Explicit (start, end) timestamp pairs
        ([(0, 20), (20, 35), (35, 40)], [2, 2, 1]),
        # Absolute durations
        ([20, 15, 5], [2, 2, 1]),
        # Fractions of the time range (with remainder)
        ([1 / 2, 3 / 8, 1 / 8], [2, 2, 1]),
    ],
)
def test_time_series_split(
    lengths: Sequence[Union[Tuple[int, int], int, float]],
    expected_lengths: Sequence[int],
) -> None:
    """Check split sizes, disjointness, coverage and indexing of the splits."""
    # Same spatial box at four consecutive 10-unit time intervals
    items = [
        (BoundingBox(0, 1, 0, 1, start, start + 10), label)
        for start, label in zip(range(0, 40, 10), "abcd")
    ]
    ds = CustomGeoDataset(items)

    train_ds, val_ds, test_ds = time_series_split(ds, lengths)

    # Each split holds the expected number of items
    sizes = [len(split) for split in (train_ds, val_ds, test_ds)]
    assert sizes == list(expected_lengths)

    # Splits are pairwise disjoint
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        assert len(first & second) == 0

    # Together the splits span the bounds of the original dataset
    merged = train_ds | val_ds | test_ds
    assert merged.bounds == ds.bounds

    # Indexing a split with its own bounds yields a sample dict
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)
|
||
|
||
def test_time_series_split_invalid_input() -> None:
    """Check that time_series_split rejects each kind of malformed ``lengths``.

    The default CustomGeoDataset is used, whose single item spans time 0 to 40.
    """
    # (expected error message pattern, offending lengths), checked in order
    cases: List[Tuple[str, Sequence[Any]]] = [
        (
            "Pairs of timestamps in lengths must have end greater than start.",
            [(0, 20), (35, 20), (35, 40)],
        ),
        (
            "Pairs of timestamps in lengths must cover dataset's time bounds.",
            [(0, 20), (20, 35)],
        ),
        (
            "Pairs of timestamps in lengths can't be out of dataset's time bounds.",
            [(0, 20), (20, 45)],
        ),
        (
            "Pairs of timestamps in lengths can't overlap.",
            [(0, 10), (10, 20), (15, 40)],
        ),
        (
            "Sum of input lengths must equal 1 or the dataset's time length.",
            [1 / 2, 1 / 2, 1 / 2],
        ),
        (
            "All items in input lengths must be greater than 0.",
            [20, 25, -5],
        ),
    ]

    for message, bad_lengths in cases:
        with pytest.raises(ValueError, match=message):
            time_series_split(CustomGeoDataset(), lengths=bad_lengths)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.