diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index eefb9849f8..b4f58df8e8 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -369,6 +369,11 @@ PASTIS
 
 .. autoclass:: PASTIS
 
+SubstationDataset
+^^^^^^^^^^^^^^^^^
+
+.. autoclass:: SubstationDataset
+
 PatternNet
 ^^^^^^^^^^
 
diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv
index 5608694036..0eb49814a3 100644
--- a/docs/api/datasets/non_geo_datasets.csv
+++ b/docs/api/datasets/non_geo_datasets.csv
@@ -49,6 +49,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `SSL4EO`_-S12,T,Sentinel-1/2,"CC-BY-4.0",1M,-,264x264,10,"SAR, MSI"
 `SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
 `SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
+`SubstationDataset`_,S,OpenStreetMap & Sentinel-2,"CC-BY-SA-2.0",27K,2,228x228,10,MSI
 `SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI
 `Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI
 `UC Merced`_,C,USGS National Map,"public domain","2,100",21,256x256,0.3,RGB
diff --git a/tests/data/substation_seg/data.py b/tests/data/substation_seg/data.py
new file mode 100644
index 0000000000..356eec8385
--- /dev/null
+++ b/tests/data/substation_seg/data.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+
+import numpy as np
+
+# Parameters
+SIZE = 228  # Image dimensions
+NUM_SAMPLES = 5  # Number of samples
+np.random.seed(0)
+
+# Define directory hierarchy
+FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]
+
+filenames: FILENAME_HIERARCHY = {'image_stack': ['image'], 'mask': ['mask']}
+
+
+def create_file(path: str, value: str) -> None:
+    """Generate .npz files for images or masks.
+
+    Args:
+        path: Base path for the generated files, without the sample suffix.
+        value: Either 'image' or 'mask', selecting which kind of array to write.
+    """
+    for i in range(NUM_SAMPLES):
+        new_path = f'{path}_{i}.npz'
+
+        if value == 'image':
+            # Image data with shape (4, 13, SIZE, SIZE): 4 timepoints, 13 channels
+            data = np.random.rand(4, 13, SIZE, SIZE).astype(np.float32)
+        elif value == 'mask':
+            # Mask data with shape (SIZE, SIZE) and 4 classes
+            data = np.random.randint(0, 4, size=(SIZE, SIZE)).astype(np.uint8)
+
+        np.savez_compressed(new_path, arr_0=data)
+
+
+def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
+    """Recursively create the directory structure and populate it with data files.
+
+    Args:
+        directory: Base directory for the dataset.
+        hierarchy: Directory and file structure.
+ """ + if isinstance(hierarchy, dict): + # Recursive case + for key, value in hierarchy.items(): + path = os.path.join(directory, key) + os.makedirs(path, exist_ok=True) + create_directory(path, value) + else: + # Base case + for value in hierarchy: + path = os.path.join(directory, 'image') + create_file(path, value) + + +if __name__ == '__main__': + # Generate directory structure and data + create_directory('.', filenames) + + # Create zip archives of dataset folders + filename_images = 'image_stack.tar.gz' + filename_masks = 'mask.tar.gz' + shutil.make_archive('image_stack', 'gztar', '.', 'image_stack') + shutil.make_archive('mask', 'gztar', '.', 'mask') + + # Compute and print MD5 checksums for data validation + with open(filename_images, 'rb') as f: + md5_images = hashlib.md5(f.read()).hexdigest() + print(f'{filename_images}: {md5_images}') + + with open(filename_masks, 'rb') as f: + md5_masks = hashlib.md5(f.read()).hexdigest() + print(f'{filename_masks}: {md5_masks}') diff --git a/tests/data/substation_seg/image_stack.tar.gz b/tests/data/substation_seg/image_stack.tar.gz new file mode 100644 index 0000000000..23b92374da Binary files /dev/null and b/tests/data/substation_seg/image_stack.tar.gz differ diff --git a/tests/data/substation_seg/image_stack/image_0.npz b/tests/data/substation_seg/image_stack/image_0.npz new file mode 100644 index 0000000000..d460c779f8 Binary files /dev/null and b/tests/data/substation_seg/image_stack/image_0.npz differ diff --git a/tests/data/substation_seg/image_stack/image_1.npz b/tests/data/substation_seg/image_stack/image_1.npz new file mode 100644 index 0000000000..0f7e31edaa Binary files /dev/null and b/tests/data/substation_seg/image_stack/image_1.npz differ diff --git a/tests/data/substation_seg/image_stack/image_2.npz b/tests/data/substation_seg/image_stack/image_2.npz new file mode 100644 index 0000000000..4c3504be0e Binary files /dev/null and b/tests/data/substation_seg/image_stack/image_2.npz differ diff --git a/tests/data/substation_seg/image_stack/image_3.npz b/tests/data/substation_seg/image_stack/image_3.npz new file mode 100644 index 0000000000..0104c26731 Binary files /dev/null and b/tests/data/substation_seg/image_stack/image_3.npz differ diff --git a/tests/data/substation_seg/image_stack/image_4.npz b/tests/data/substation_seg/image_stack/image_4.npz new file mode 100644 index 0000000000..1adf8f7c3e Binary files /dev/null and b/tests/data/substation_seg/image_stack/image_4.npz differ diff --git a/tests/data/substation_seg/mask.tar.gz b/tests/data/substation_seg/mask.tar.gz new file mode 100644 index 0000000000..887debae63 Binary files /dev/null and b/tests/data/substation_seg/mask.tar.gz differ diff --git a/tests/data/substation_seg/mask/image_0.npz b/tests/data/substation_seg/mask/image_0.npz new file mode 100644 index 0000000000..1559f93353 Binary files /dev/null and b/tests/data/substation_seg/mask/image_0.npz differ diff --git a/tests/data/substation_seg/mask/image_1.npz b/tests/data/substation_seg/mask/image_1.npz new file mode 100644 index 0000000000..56a1e5cc97 Binary files /dev/null and b/tests/data/substation_seg/mask/image_1.npz differ diff --git a/tests/data/substation_seg/mask/image_2.npz b/tests/data/substation_seg/mask/image_2.npz new file mode 100644 index 0000000000..9d0094bbff Binary files /dev/null and b/tests/data/substation_seg/mask/image_2.npz differ diff --git a/tests/data/substation_seg/mask/image_3.npz b/tests/data/substation_seg/mask/image_3.npz new file mode 100644 index 0000000000..3011ce9dd2 Binary files 
diff --git a/tests/data/substation_seg/mask/image_4.npz b/tests/data/substation_seg/mask/image_4.npz
new file mode 100644
index 0000000000..e161f9b972
Binary files /dev/null and b/tests/data/substation_seg/mask/image_4.npz differ
diff --git a/tests/datasets/test_substation_seg.py b/tests/datasets/test_substation_seg.py
new file mode 100644
index 0000000000..c216934bb3
--- /dev/null
+++ b/tests/datasets/test_substation_seg.py
@@ -0,0 +1,267 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from collections.abc import Generator
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pytest
+import torch
+import torchvision.transforms as transforms
+
+from torchgeo.datasets import SubstationDataset
+
+
+class Args:
+    """Mocked arguments for testing SubstationDataset."""
+
+    def __init__(self) -> None:
+        self.data_dir: str = os.path.join(os.getcwd(), 'tests', 'data')
+        self.in_channels: int = 13
+        self.use_timepoints: bool = True
+        self.normalizing_type: str = 'percentile'
+        self.mask_2d: bool = True
+        self.model_type: str = 'vanilla_unet'
+        self.timepoint_aggregation: str = 'median'
+        self.color_transforms: bool = False
+        self.geo_transforms: bool = False
+        self.normalizing_factor: Any = np.array([[0, 0.5, 1.0]], dtype=np.float32)
+        # Per-band means and standard deviations, shape (13, 1, 1)
+        self.means: Any = np.array(
+            [1431, 1233, 1209, 1192, 1448, 2238, 2609, 2537, 2828, 884, 20, 2226, 1537],
+            dtype=np.float32,
+        ).reshape(-1, 1, 1)
+        self.stds: Any = np.array(
+            [157, 254, 290, 420, 363, 457, 575, 606, 630, 156, 3, 554, 523],
+            dtype=np.float32,
+        ).reshape(-1, 1, 1)
+
+
+@pytest.fixture
+def dataset() -> Generator[SubstationDataset, None, None]:
+    """Fixture for the SubstationDataset."""
+    args = Args()
+    image_files = ['image_0.npz', 'image_1.npz']
+
+    yield SubstationDataset(args, image_files)
+
+
+@pytest.mark.parametrize(
+    'config',
+    [
+        {
+            'normalizing_type': 'percentile',
+            'in_channels': 3,
+            'use_timepoints': False,
+            'mask_2d': True,
+        },
+        {
+            'normalizing_type': 'zscore',
+            'in_channels': 9,
+            'model_type': 'swin',
+            'use_timepoints': True,
+            'timepoint_aggregation': 'concat',
+            'mask_2d': False,
+        },
+        {
+            'normalizing_type': None,
+            'in_channels': 12,
+            'use_timepoints': True,
+            'timepoint_aggregation': 'median',
+            'mask_2d': True,
+            'normalizing_factor': 1.0,
+        },
+        {
+            'normalizing_type': None,
+            'in_channels': 5,
+            'use_timepoints': True,
+            'timepoint_aggregation': 'first',
+            'mask_2d': False,
+            'normalizing_factor': 1.0,
+        },
+        {
+            'normalizing_type': None,
+            'in_channels': 4,
+            'use_timepoints': True,
+            'timepoint_aggregation': 'random',
+            'mask_2d': True,
+            'normalizing_factor': 1.0,
+        },
+        {
+            'normalizing_type': 'zscore',
+            'in_channels': 2,
+            'use_timepoints': False,
+            'mask_2d': False,
+            'color_transforms': True,
+            'geo_transforms': True,
+        },
+        {
+            'normalizing_type': None,
+            'in_channels': 5,
+            'use_timepoints': False,
+            'timepoint_aggregation': 'first',
+            'mask_2d': False,
+            'normalizing_factor': 1.0,
+        },
+        {
+            'normalizing_type': None,
+            'in_channels': 4,
+            'use_timepoints': False,
+            'timepoint_aggregation': 'random',
+            'mask_2d': True,
+            'normalizing_factor': 1.0,
+        },
+    ],
+)
+def test_getitem_semantic(config: dict[str, Any]) -> None:
+    args = Args()
+    for key, value in config.items():
+        setattr(args, key, value)  # Override defaults with the parametrized config
+
+    image_files = ['image_0.npz', 'image_1.npz']
+    image_resize = transforms.Compose(
+        [transforms.Resize(228, transforms.InterpolationMode.BICUBIC)]
+    )
+    mask_resize = transforms.Compose(
+        [transforms.Resize(228, transforms.InterpolationMode.NEAREST)]
+    )
+    dataset = SubstationDataset(
+        args, image_files, image_resize=image_resize, mask_resize=mask_resize
+    )
+
+    x = dataset[0]
+    assert isinstance(x, dict), f'Expected dict, got {type(x)}'
+    assert isinstance(x['image'], torch.Tensor), 'Expected image to be a torch.Tensor'
+    assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor'
+
+
+def test_len(dataset: SubstationDataset) -> None:
+    """Test the length of the dataset."""
+    assert len(dataset) == 2
+
+
+def test_output_shape(dataset: SubstationDataset) -> None:
+    """Test the output shape of the dataset."""
+    x = dataset[0]
+    assert x['image'].shape == torch.Size([13, 228, 228])
+    assert x['mask'].shape == torch.Size([2, 228, 228])
+
+
+def test_plot(dataset: SubstationDataset, monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test the plot method of the dataset."""
+    # Mock plt.show to avoid displaying the plot during the test
+    mock_show = MagicMock()
+    monkeypatch.setattr(plt, 'show', mock_show)
+
+    # Mock np.random.randint to always return a fixed index; the lambda takes
+    # the same (low, high) arguments as the function it replaces
+    monkeypatch.setattr(np.random, 'randint', lambda low, high: 0)
+
+    # Mock __getitem__ to return a sample with a 3-channel (RGB) image and a mask
+    mock_image = torch.rand(3, 228, 228)
+    mock_mask = torch.randint(0, 4, (228, 228))
+    monkeypatch.setattr(
+        dataset, '__getitem__', lambda idx: {'image': mock_image, 'mask': mock_mask}
+    )
+
+    # Call the plot method
+    dataset.plot()
+
+
+def test_already_downloaded(
+    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Test that the dataset doesn't re-download if already present."""
+    # Copy the archives into the target directory to simulate files that have
+    # already been downloaded
+    image_archive = os.path.join(
+        'tests', 'data', 'substation_seg', 'image_stack.tar.gz'
+    )
+    mask_archive = os.path.join('tests', 'data', 'substation_seg', 'mask.tar.gz')
+    shutil.copy(image_archive, tmp_path)
+    shutil.copy(mask_archive, tmp_path)
+
+    # Since the files are already present, no real download should be attempted;
+    # mock _download so any call is a no-op
+    monkeypatch.setattr(dataset, '_download', MagicMock())
+    dataset._download()  # Calls the mocked method
+
+
+def test_verify(dataset: SubstationDataset, monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test the _verify method of the dataset."""
+    # Mock os.path.exists to report missing image and mask directories
+    monkeypatch.setattr(os.path, 'exists', lambda path: False)
+
+    # Mock _download to avoid actually downloading the dataset
+    mock_download = MagicMock()
+    monkeypatch.setattr(dataset, '_download', mock_download)
+
+    # Call the _verify method
+    dataset._verify()
+
+    # Check that _download was called
+    mock_download.assert_called_once()
+
+
+def test_download(
+    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Test the _download method of the dataset."""
+    # Mock the download_url and extract_archive functions
+    mock_download_url = MagicMock()
+    mock_extract_archive = MagicMock()
+    monkeypatch.setattr(
+        'torchgeo.datasets.substation_seg.download_url', mock_download_url
+    )
+    monkeypatch.setattr(
+        'torchgeo.datasets.substation_seg.extract_archive', mock_extract_archive
+    )
+
+    # Call the _download method
+    dataset._download()
+
+    # Check that download_url and extract_archive were each called twice,
+    # once for the images and once for the masks
+    assert mock_download_url.call_count == 2
+    assert mock_extract_archive.call_count == 2
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 663b08e7cb..3cb62c41c4 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -128,6 +128,7 @@
 )
 from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12
 from .ssl4eo_benchmark import SSL4EOLBenchmark
+from .substation_seg import SubstationDataset
 from .sustainbench_crop_yield import SustainBenchCropYield
 from .ucmerced import UCMerced
 from .usavars import USAVars
@@ -263,6 +264,7 @@
     'SSL4EOLBenchmark',
     'SSL4EOL',
     'SSL4EOS12',
+    'SubstationDataset',
     'SustainBenchCropYield',
     'TropicalCyclone',
     'UCMerced',
diff --git a/torchgeo/datasets/substation_seg.py b/torchgeo/datasets/substation_seg.py
new file mode 100644
index 0000000000..a5ce66a0fe
--- /dev/null
+++ b/torchgeo/datasets/substation_seg.py
@@ -0,0 +1,170 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Substation segmentation dataset."""
+
+import os
+from typing import Any
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import download_url, extract_archive
+
+
+class SubstationDataset(NonGeoDataset):
+    """Dataset for substation segmentation.
+
+    The dataset pairs multi-temporal, 13-band Sentinel-2 image patches with
+    binary masks of electrical substations derived from OpenStreetMap. It
+    extends NonGeoDataset and provides methods for dataset verification,
+    downloading, and transformation.
+    """
+
+    directory: str = 'Substation'
+    filename_images: str = 'image_stack.tar.gz'
+    filename_masks: str = 'mask.tar.gz'
+    url_for_images: str = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/image_stack.tar.gz'
+    url_for_masks: str = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/mask.tar.gz'
+
+    def __init__(
+        self,
+        args: Any,
+        image_files: list[str],
+        geo_transforms: Any | None = None,
+        color_transforms: Any | None = None,
+        image_resize: Any | None = None,
+        mask_resize: Any | None = None,
+    ) -> None:
+        """Initialize the SubstationDataset.
+
+        Args:
+            args: object carrying dataset parameters; must provide ``data_dir``,
+                ``in_channels``, ``use_timepoints``, ``timepoint_aggregation``,
+                ``normalizing_type``, ``normalizing_factor``, ``mask_2d``, and
+                ``model_type``.
+            image_files: list of image file names.
+            geo_transforms: geometric transformations applied to both the images
+                and the masks. Defaults to None.
+            color_transforms: color transformations applied to the images.
+                Defaults to None.
+            image_resize: transformation for resizing the images. Defaults to None.
+            mask_resize: transformation for resizing the masks. Defaults to None.
+        """
+        self.data_dir = args.data_dir
+        self.geo_transforms = geo_transforms
+        self.color_transforms = color_transforms
+        self.image_resize = image_resize
+        self.mask_resize = mask_resize
+        self.in_channels = args.in_channels
+        self.use_timepoints = args.use_timepoints
+        self.normalizing_type = args.normalizing_type
+        self.normalizing_factor = args.normalizing_factor
+        self.mask_2d = args.mask_2d
+        self.model_type = args.model_type
+        self.image_dir = os.path.join(args.data_dir, 'substation_seg', 'image_stack')
+        self.mask_dir = os.path.join(args.data_dir, 'substation_seg', 'mask')
+        self.image_filenames = image_files
+        self.args = args
+
+    def __getitem__(self, index: int) -> dict[str, Tensor]:
+        """Get an item from the dataset by index.
+
+        Args:
+            index: index of the item to retrieve.
+
+        Returns:
+            A dictionary containing the image and corresponding mask.
+        """
+        image_filename = self.image_filenames[index]
+        image_path = os.path.join(self.image_dir, image_filename)
+        mask_path = os.path.join(self.mask_dir, image_filename)
+
+        image = np.load(image_path)['arr_0']
+
+        # Standardize the image
+        if self.normalizing_type == 'percentile':
+            image = (
+                image - self.normalizing_factor[:, 0].reshape((-1, 1, 1))
+            ) / self.normalizing_factor[:, 2].reshape((-1, 1, 1))
+        elif self.normalizing_type == 'zscore':
+            image = (image - self.args.means) / self.args.stds
+        else:
+            image = image / self.normalizing_factor
+        # Clip the image to the [0, 1] range
+        image = np.clip(image, 0, 1)
+
+        # Select channels
+        if self.in_channels == 3:
+            image = image[:, [3, 2, 1], :, :]
+        else:
+            if self.model_type == 'swin':
+                # Swin only takes 9 channels
+                image = image[:, [3, 2, 1, 4, 5, 6, 7, 10, 11], :, :]
+            else:
+                image = image[:, : self.in_channels, :, :]
+
+        # Aggregate multiple images across timepoints
+        if self.use_timepoints:
+            image = image[:4, :, :, :]
+            if self.args.timepoint_aggregation == 'concat':
+                # (4*channels, h, w)
+                image = np.reshape(image, (-1, image.shape[2], image.shape[3]))
+            elif self.args.timepoint_aggregation == 'median':
+                image = np.median(image, axis=0)
+        else:
+            if self.args.timepoint_aggregation == 'first':
+                image = image[0]
+            elif self.args.timepoint_aggregation == 'random':
+                image = image[np.random.randint(image.shape[0])]
+
+        # Binarize the mask: class 3 (substation) becomes 1, everything else 0
+        mask = np.load(mask_path)['arr_0']
+        mask[mask != 3] = 0
+        mask[mask == 3] = 1
+
+        image = torch.from_numpy(image)
+        mask = torch.from_numpy(mask).float()
+        mask = mask.unsqueeze(dim=0)
+
+        if self.mask_2d:
+            mask_0 = 1.0 - mask
+            mask = torch.cat([mask_0, mask], dim=0)
+
+        if self.image_resize:
+            image = self.image_resize(image)
+
+        if self.mask_resize:
+            mask = self.mask_resize(mask)
+
+        return {'image': image, 'mask': mask}
+
+    def __len__(self) -> int:
+        """Return the number of items in the dataset."""
+        return len(self.image_filenames)
+
+    def plot(self) -> None:
+        """Plot a random image and mask from the dataset."""
from the dataset.""" + index = np.random.randint(0, self.__len__()) + data = self.__getitem__(index) + image = data['image'] + mask = data['mask'] + + fig, axs = plt.subplots(1, 2, figsize=(15, 15)) + axs[0].imshow(image.permute(1, 2, 0).cpu().numpy()) + axs[1].imshow(image.permute(1, 2, 0).cpu().numpy()) + axs[1].imshow(mask.squeeze().cpu().numpy(), alpha=0.5, cmap='gray') + + def _verify(self) -> None: + """Checks if dataset exists, otherwise download it.""" + image_dir_exists = os.path.exists(self.image_dir) + mask_dir_exists = os.path.exists(self.mask_dir) + if not (image_dir_exists and mask_dir_exists): + self._download() + + def _download(self) -> None: + """Download the dataset.""" + # Assuming self.url_for_images and self.url_for_masks are URLs for dataset components + download_url(self.url_for_images, self.data_dir, filename=self.filename_images) + extract_archive( + os.path.join(self.data_dir, self.filename_images), self.data_dir + ) + + download_url(self.url_for_masks, self.data_dir, filename=self.filename_masks) + extract_archive(os.path.join(self.data_dir, self.filename_masks), self.data_dir)